From 52e476d25b09833e0cba75385c0bfc42638e4f5b Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 10 May 2019 19:30:16 +0530 Subject: [PATCH 01/32] add first compiler pass --- src/compiler.jl | 60 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/compiler.jl diff --git a/src/compiler.jl b/src/compiler.jl new file mode 100644 index 00000000..5b45961c --- /dev/null +++ b/src/compiler.jl @@ -0,0 +1,60 @@ +using IRTools: isexpr, IR, Variable, @dynamo, @code_ir, prewalk +using IRTools: xcall, prewalk, postwalk, insertafter!, exprtype +using IRTools: Argument, meta, arg, NewVariable +using IRTools: Pipe, stmt, var, finish +using IRTools +import Adapt, GPUArrays + +# [4] TODO: clean iscuable fn + +function iscuable(x) + x isa NewVariable && return false + x isa Variable && return true + x isa Argument && return true + x isa Symbol && return false + x isa QuoteNode && return false + x isa GlobalRef && x.mod ∈ (CuArrays, Core, Adapt, GPUArrays) && return false + + x isa Expr && x.args[1] isa GlobalRef && x.args[1].mod ∈ (CuArrays, Core, Adapt, GPUArrays) && return false + x isa Expr && x.args[1] ∈ (:Base, :CuArrays, :Core, :Adapt) && return false + x isa Expr && x.args[1] == :cu && return false + x isa Expr && length(x.args) == 1 && return false + x isa Expr && x.args[1] isa Variable && return false + x isa Expr && x.args[2] isa QuoteNode && return true + + x isa GlobalRef && x.mod ∉ (Base, Core) || return true + isconst(x.mod, x.name) || return true + x = getfield(x.mod, x.name) + !(x isa Type || sizeof(x) == 0) +end + + +function traverse_and_insert!(ir, v, ex) + ir[v] = postwalk(ex) do x + iscuable(x) || return x + insert!(ir, v, stmt(Expr(:call, GlobalRef(CuArrays, :cu), x))) + end +end + +∉(a,b) = !in(a,b) + +@dynamo function cuize(meta) + meta == nothing && return + ir = IR(meta) + ir == nothing && return + pr = Pipe(ir) + pr == nothing && return + + for (v,st) in pr + ex = st.expr + ex = traverse_and_insert!(pr, v, ex) + + # Comment this to make it work + # pr[v] = Expr(:call, GlobalRef(Main, :cuize), st.expr.args...) + end + + return finish(pr) +end + + + From 036a1709a1a99cdf8bb3f0bc32242728b29673dc Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 5 Jun 2019 05:22:00 +0530 Subject: [PATCH 02/32] add compiler pass --- src/compiler.jl | 125 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 85 insertions(+), 40 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 5b45961c..a2164e94 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,60 +1,105 @@ using IRTools: isexpr, IR, Variable, @dynamo, @code_ir, prewalk using IRTools: xcall, prewalk, postwalk, insertafter!, exprtype -using IRTools: Argument, meta, arg, NewVariable +using IRTools: meta, NewVariable using IRTools: Pipe, stmt, var, finish using IRTools import Adapt, GPUArrays # [4] TODO: clean iscuable fn -function iscuable(x) - x isa NewVariable && return false - x isa Variable && return true - x isa Argument && return true - x isa Symbol && return false - x isa QuoteNode && return false - x isa GlobalRef && x.mod ∈ (CuArrays, Core, Adapt, GPUArrays) && return false - - x isa Expr && x.args[1] isa GlobalRef && x.args[1].mod ∈ (CuArrays, Core, Adapt, GPUArrays) && return false - x isa Expr && x.args[1] ∈ (:Base, :CuArrays, :Core, :Adapt) && return false - x isa Expr && x.args[1] == :cu && return false - x isa Expr && length(x.args) == 1 && return false - x isa Expr && x.args[1] isa Variable && return false - x isa Expr && x.args[2] isa QuoteNode && return true - - x isa GlobalRef && x.mod ∉ (Base, Core) || return true - isconst(x.mod, x.name) || return true - x = getfield(x.mod, x.name) - !(x isa Type || sizeof(x) == 0) -end +# function iscuable(x) +# x isa NewVariable && return false +# x isa Variable && return true +# # x isa Argument && return true +# x isa Symbol && return false +# x isa QuoteNode && return false +# x isa GlobalRef && x.mod ∈ (CuArrays, Core, Adapt, GPUArrays) && return false +# x isa Expr && x.args[1] isa GlobalRef && x.args[1].mod ∈ (CuArrays, Core, Adapt, GPUArrays) && return false +# x isa Expr && x.args[1] ∈ (:Base, :CuArrays, :Core, :Adapt) && return false +# x isa Expr && x.args[1] == :cu && return false +# x isa Expr && length(x.args) == 1 && return false +# x isa Expr && x.args[1] isa Variable && return false +# x isa Expr && x.args[2] isa QuoteNode && return true -function traverse_and_insert!(ir, v, ex) - ir[v] = postwalk(ex) do x - iscuable(x) || return x - insert!(ir, v, stmt(Expr(:call, GlobalRef(CuArrays, :cu), x))) - end +# x isa GlobalRef && x.mod ∉ (Base, Core) || return true +# isconst(x.mod, x.name) || return true +# x = getfield(x.mod, x.name) +# !(x isa Type || sizeof(x) == 0) +# end + + +# function traverse_and_insert!(ir, v, ex) +# ir[v] = postwalk(ex) do x +# iscuable(x) || return x +# insert!(ir, v, stmt(Expr(:call, GlobalRef(CuArrays, :cu), x))) +# end +# end + +# ∉(a,b) = !in(a,b) + +# @dynamo function cuize(meta) +# meta == nothing && return +# ir = IR(meta) +# ir == nothing && return +# pr = Pipe(ir) +# pr == nothing && return + +# for (v,st) in pr +# ex = st.expr +# ex = traverse_and_insert!(pr, v, ex) + +# # Comment this to make it work +# # pr[v] = Expr(:call, GlobalRef(Main, :cuize), st.expr.args...) +# end + +# return finish(pr) +# end + + +# function travserse(pr, x, st) +# pr[x] = postwalk(st) do x +# @show x +# end +# end + +import Base.Broadcast.broadcasted +import Base.Broadcast.materialize + +# using `Array` instead of `cpu` here works but +# causes tracking issues with TrackedArray +cuize(::typeof(*), a, b) = cpu(cu(a) * cu(b)) +cuize(::typeof(+), a, b) = cpu(cu(a) + cu(b)) +cuize(::typeof(-), a, b) = cpu(cu(a) - cu(b)) +cuize(::typeof(/), a, b) = cpu(cu(a) / cu(b)) +cuize(::typeof(materialize), bc) = materialize(bc) + +function cuize(::typeof(broadcasted), f, a...) + b = map(cu, a) + broadcasted(f, b...) end -∉(a,b) = !in(a,b) +# @dynamo function cuize(meta) +# meta == nothing && return +# ir = IR(meta) +# for (x,st) in ir +# isexpr(st.expr, :call) || continue +# ex = st.expr +# ex = travserse(ir, x, ex) +# ir[x] = xcall(Main, :cuize, st.expr.args...) +# end +# ir +# end -@dynamo function cuize(meta) +@dynamo function cuize(meta...) meta == nothing && return - ir = IR(meta) + ir = IR(meta...) ir == nothing && return pr = Pipe(ir) - pr == nothing && return - for (v,st) in pr + isexpr(st.expr, :call) || continue ex = st.expr - ex = traverse_and_insert!(pr, v, ex) - - # Comment this to make it work - # pr[v] = Expr(:call, GlobalRef(Main, :cuize), st.expr.args...) + pr[v] = Expr(:call, GlobalRef(Main, :cuize), st.expr.args...) end - - return finish(pr) + return finish(pr) end - - - From a3246fdc417d0f0aadb6d92406e4c42f22a3351f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 5 Jun 2019 13:34:52 +0530 Subject: [PATCH 03/32] cleanup --- src/compiler.jl | 78 ++----------------------------------------------- 1 file changed, 2 insertions(+), 76 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index a2164e94..652fe92b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,67 +1,5 @@ -using IRTools: isexpr, IR, Variable, @dynamo, @code_ir, prewalk -using IRTools: xcall, prewalk, postwalk, insertafter!, exprtype -using IRTools: meta, NewVariable -using IRTools: Pipe, stmt, var, finish -using IRTools -import Adapt, GPUArrays - -# [4] TODO: clean iscuable fn - -# function iscuable(x) -# x isa NewVariable && return false -# x isa Variable && return true -# # x isa Argument && return true -# x isa Symbol && return false -# x isa QuoteNode && return false -# x isa GlobalRef && x.mod ∈ (CuArrays, Core, Adapt, GPUArrays) && return false - -# x isa Expr && x.args[1] isa GlobalRef && x.args[1].mod ∈ (CuArrays, Core, Adapt, GPUArrays) && return false -# x isa Expr && x.args[1] ∈ (:Base, :CuArrays, :Core, :Adapt) && return false -# x isa Expr && x.args[1] == :cu && return false -# x isa Expr && length(x.args) == 1 && return false -# x isa Expr && x.args[1] isa Variable && return false -# x isa Expr && x.args[2] isa QuoteNode && return true - -# x isa GlobalRef && x.mod ∉ (Base, Core) || return true -# isconst(x.mod, x.name) || return true -# x = getfield(x.mod, x.name) -# !(x isa Type || sizeof(x) == 0) -# end - - -# function traverse_and_insert!(ir, v, ex) -# ir[v] = postwalk(ex) do x -# iscuable(x) || return x -# insert!(ir, v, stmt(Expr(:call, GlobalRef(CuArrays, :cu), x))) -# end -# end - -# ∉(a,b) = !in(a,b) - -# @dynamo function cuize(meta) -# meta == nothing && return -# ir = IR(meta) -# ir == nothing && return -# pr = Pipe(ir) -# pr == nothing && return - -# for (v,st) in pr -# ex = st.expr -# ex = traverse_and_insert!(pr, v, ex) - -# # Comment this to make it work -# # pr[v] = Expr(:call, GlobalRef(Main, :cuize), st.expr.args...) -# end - -# return finish(pr) -# end - - -# function travserse(pr, x, st) -# pr[x] = postwalk(st) do x -# @show x -# end -# end +using IRTools: isexpr, IR, @dynamo +using IRTools: meta, Pipe, finish import Base.Broadcast.broadcasted import Base.Broadcast.materialize @@ -79,18 +17,6 @@ function cuize(::typeof(broadcasted), f, a...) broadcasted(f, b...) end -# @dynamo function cuize(meta) -# meta == nothing && return -# ir = IR(meta) -# for (x,st) in ir -# isexpr(st.expr, :call) || continue -# ex = st.expr -# ex = travserse(ir, x, ex) -# ir[x] = xcall(Main, :cuize, st.expr.args...) -# end -# ir -# end - @dynamo function cuize(meta...) meta == nothing && return ir = IR(meta...) From 7495b98410436bf940afa48ee12e80607923c77d Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 13 Jun 2019 20:10:22 +0530 Subject: [PATCH 04/32] cache cu'd arrays --- src/compiler.jl | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 652fe92b..5222b111 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4,12 +4,43 @@ using IRTools: meta, Pipe, finish import Base.Broadcast.broadcasted import Base.Broadcast.materialize +# array_bank = IdDict() +__context__ = IdDict() +# __context__ = WeakKeyDict() + + +for f in (:+, :-, :*, :/) + q = quote + function cuize(::typeof($f), a, b) + a = get_cached(__context__, a) + # @show typeof(a) + b = get_cached(__context__, b) + + c = $f(a, b) + get_cached(__context__, c) + end + end + eval(q) +end + +function get_cached(__context___, arr) + @show "here" + haskey(__context__, arr) ? + __context___[arr] : + cache(__context___, arr) + +end + +function cache(__context___, arr) + __context___[arr] = cu(arr) +end + # using `Array` instead of `cpu` here works but # causes tracking issues with TrackedArray -cuize(::typeof(*), a, b) = cpu(cu(a) * cu(b)) -cuize(::typeof(+), a, b) = cpu(cu(a) + cu(b)) -cuize(::typeof(-), a, b) = cpu(cu(a) - cu(b)) -cuize(::typeof(/), a, b) = cpu(cu(a) / cu(b)) +# cuize(::typeof(*), a, b) = cpu(cu(a) * cu(b)) +# cuize(::typeof(+), a, b) = cpu(cu(a) + cu(b)) +# cuize(::typeof(-), a, b) = cpu(cu(a) - cu(b)) +# cuize(::typeof(/), a, b) = cpu(cu(a) / cu(b)) cuize(::typeof(materialize), bc) = materialize(bc) function cuize(::typeof(broadcasted), f, a...) From e0965c17ddce4984e1ccb7c7e8a3380dbf5204a4 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 13 Jun 2019 21:46:00 +0530 Subject: [PATCH 05/32] use cached results --- src/compiler.jl | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 5222b111..835775a1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4,35 +4,41 @@ using IRTools: meta, Pipe, finish import Base.Broadcast.broadcasted import Base.Broadcast.materialize -# array_bank = IdDict() + +# Hold all the arrays related to the op +array_bank = WeakKeyDict{AbstractArray, AbstractArray}() + +# Hold all the results related to the op, but permanently __context__ = IdDict() -# __context__ = WeakKeyDict() for f in (:+, :-, :*, :/) q = quote function cuize(::typeof($f), a, b) - a = get_cached(__context__, a) + a = get_cached(array_bank, a) # @show typeof(a) - b = get_cached(__context__, b) + b = get_cached(array_bank, b) - c = $f(a, b) - get_cached(__context__, c) + if haskey(__context__, ($f, a, b)) + __context__[($f, a, b)] + else + c = $f(a, b) + __context__[($f, a, b)] = c + end end end eval(q) end -function get_cached(__context___, arr) - @show "here" - haskey(__context__, arr) ? - __context___[arr] : - cache(__context___, arr) +function get_cached(array_bank, arr) + haskey(array_bank, arr) ? + array_bank[arr] : + cache(array_bank, arr) end -function cache(__context___, arr) - __context___[arr] = cu(arr) +function cache(array_bank, arr) + array_bank[arr] = cu(arr) end # using `Array` instead of `cpu` here works but From 82f8696f73698a0dc16ac70d4aa32ac358759d0b Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 14 Jun 2019 00:44:08 +0530 Subject: [PATCH 06/32] fixes --- src/compiler.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 835775a1..882f22e1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -15,14 +15,13 @@ __context__ = IdDict() for f in (:+, :-, :*, :/) q = quote function cuize(::typeof($f), a, b) - a = get_cached(array_bank, a) - # @show typeof(a) - b = get_cached(array_bank, b) if haskey(__context__, ($f, a, b)) - __context__[($f, a, b)] + return __context__[($f, a, b)] else - c = $f(a, b) + ga = get_cached(array_bank, a) + gb = get_cached(array_bank, b) + c = $f(ga, gb) __context__[($f, a, b)] = c end end From 44f802c151ca9013cf71566d60cff64dd8ee9836 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 18 Jun 2019 11:27:10 +0530 Subject: [PATCH 07/32] use IdDict over WeakKeyDict --- src/compiler.jl | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 882f22e1..bdc39e69 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6,24 +6,46 @@ import Base.Broadcast.materialize # Hold all the arrays related to the op -array_bank = WeakKeyDict{AbstractArray, AbstractArray}() +# array_bank = WeakKeyDict{AbstractArray, AbstractArray}() +# array_bank = WeakKeyDict() +array_bank = IdDict() # Hold all the results related to the op, but permanently __context__ = IdDict() +# function cache(__context__, f, args...) +# q = quote +# function cuize(::typeof($f), args...) +# if haskey(__context__, ($f, args...)) +# __context__[($f, args...)] +# else +# gargs = map(x -> get_cached(array_bank, x), args) +# c = $f(gargs...) +# __context__[($f, args...)] = c +# end +# end +# end +# eval(q) +# end + for f in (:+, :-, :*, :/) q = quote function cuize(::typeof($f), a, b) + # @show length(a) + # ga = get_cached(array_bank, a) + # gb = get_cached(array_bank, b) - if haskey(__context__, ($f, a, b)) - return __context__[($f, a, b)] - else + # c = $f(ga, gb) + # __context__[c] = c + # if haskey(__context__, ($f, a, b)) + # __context__[($f, a, b)] + # else ga = get_cached(array_bank, a) gb = get_cached(array_bank, b) c = $f(ga, gb) - __context__[($f, a, b)] = c - end + # __context__[($f, a, b)] = c + # end end end eval(q) From 601a6f438386cfdc1cdf3078bd2c2b17fd1c725c Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 18 Jun 2019 16:42:52 +0530 Subject: [PATCH 08/32] fixes --- src/compiler.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index bdc39e69..ddab2caf 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -29,6 +29,18 @@ __context__ = IdDict() # end +# function cuize(f, args...) +# q = quote +# function cuize(::typeof($f), args...) +# gargs = map(x -> get_cached(array_bank, x), args) +# c = $f(gargs...) +# end +# end +# eval(q) +# # cuize(f, args...) +# end + + for f in (:+, :-, :*, :/) q = quote function cuize(::typeof($f), a, b) From 9324434f078a013e8a2c6b251b9f828e8876243a Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 20 Jun 2019 21:37:39 +0530 Subject: [PATCH 09/32] separate concrete types as an opt in --- src/compiler.jl | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index ddab2caf..21a379dc 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3,6 +3,7 @@ using IRTools: meta, Pipe, finish import Base.Broadcast.broadcasted import Base.Broadcast.materialize +import Base.Broadcast.Broadcasted # Hold all the arrays related to the op @@ -43,7 +44,7 @@ __context__ = IdDict() for f in (:+, :-, :*, :/) q = quote - function cuize(::typeof($f), a, b) + function cuize(::typeof($f), a::AbstractArray, b::AbstractArray) # @show length(a) # ga = get_cached(array_bank, a) # gb = get_cached(array_bank, b) @@ -53,6 +54,7 @@ for f in (:+, :-, :*, :/) # if haskey(__context__, ($f, a, b)) # __context__[($f, a, b)] # else + # @timeit to "cuize" begin ga = get_cached(array_bank, a) gb = get_cached(array_bank, b) c = $f(ga, gb) @@ -64,12 +66,28 @@ for f in (:+, :-, :*, :/) end function get_cached(array_bank, arr) + # @show typeof(arr) + + # CuArrays can come up when you have outputs/ movements before ops + arr isa CuArray && return arr + + # Broadcasted objects are new everytime they're generated, ignore them + arr isa Broadcasted && return arr + haskey(array_bank, arr) ? array_bank[arr] : cache(array_bank, arr) end + +get_cached(x::AbstractArray) = get_cached(array_bank, x) + +# get_cached(array_bank, arr::TrackedArray) = get_cached(array_bank, Tracker.data(arr)) +get_cached(arr::TrackedArray) = get_cached(Tracker.data(arr)) + + + function cache(array_bank, arr) array_bank[arr] = cu(arr) end @@ -83,7 +101,7 @@ end cuize(::typeof(materialize), bc) = materialize(bc) function cuize(::typeof(broadcasted), f, a...) - b = map(cu, a) + b = map(x -> get_cached(array_bank, x), a) broadcasted(f, b...) end From efb50d1c869a0d265bf46b569f50cef56ebd09e2 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 1 Jul 2019 17:41:39 +0530 Subject: [PATCH 10/32] cache objects explicitly and use --- src/compiler.jl | 89 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 74 insertions(+), 15 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 21a379dc..84faa505 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -9,7 +9,7 @@ import Base.Broadcast.Broadcasted # Hold all the arrays related to the op # array_bank = WeakKeyDict{AbstractArray, AbstractArray}() # array_bank = WeakKeyDict() -array_bank = IdDict() +array_bank = IdDict{AbstractArray, AbstractArray}() # Hold all the results related to the op, but permanently __context__ = IdDict() @@ -29,6 +29,7 @@ __context__ = IdDict() # eval(q) # end +# function get_cached end # function cuize(f, args...) # q = quote @@ -41,7 +42,6 @@ __context__ = IdDict() # # cuize(f, args...) # end - for f in (:+, :-, :*, :/) q = quote function cuize(::typeof($f), a::AbstractArray, b::AbstractArray) @@ -65,30 +65,26 @@ for f in (:+, :-, :*, :/) eval(q) end -function get_cached(array_bank, arr) - # @show typeof(arr) - +function get_cached(array_bank, arr::AbstractArray) # CuArrays can come up when you have outputs/ movements before ops arr isa CuArray && return arr # Broadcasted objects are new everytime they're generated, ignore them arr isa Broadcasted && return arr + arr isa TrackedArray && arr.data isa CuArray && return arr + haskey(array_bank, arr) ? array_bank[arr] : cache(array_bank, arr) - end - -get_cached(x::AbstractArray) = get_cached(array_bank, x) +# get_cached(x::AbstractArray) = get_cached(array_bank, parent(x)) # get_cached(array_bank, arr::TrackedArray) = get_cached(array_bank, Tracker.data(arr)) -get_cached(arr::TrackedArray) = get_cached(Tracker.data(arr)) +# get_cached(array_bank, arr::TrackedArray) = get_cached(Tracker.data(arr)) - - -function cache(array_bank, arr) +function cache(array_bank, arr::AbstractArray) array_bank[arr] = cu(arr) end @@ -100,9 +96,30 @@ end # cuize(::typeof(/), a, b) = cpu(cu(a) / cu(b)) cuize(::typeof(materialize), bc) = materialize(bc) -function cuize(::typeof(broadcasted), f, a...) - b = map(x -> get_cached(array_bank, x), a) - broadcasted(f, b...) +# function cuize(::typeof(getproperty), args...) +# # @show args +# getproperty(args...) +# end + +# function cuize(::typeof(getfield), args...) +# # @info "in getfield: $(typeof.(args))" +# getfield(args...) +# # try +# # getfield(args...) +# # catch e +# # @show typeof.(args) +# # throw() +# # end +# end + +function cuize(::typeof(broadcasted), f, args...) + gargs = map(x -> get_cached(array_bank, x), args) + broadcasted(f, gargs...) +end + +function cuize(::typeof(broadcast), f, args...) + gargs = map(x -> get_cached(array_bank, x), args) + broadcast(f, gargs...) end @dynamo function cuize(meta...) @@ -117,3 +134,45 @@ end end return finish(pr) end + + + + +################################################################### + +# function cuize(f, args...) +# T = Tuple{typeof(f), typeof.(args)...} +# q = quote +# function cuize(::typeof($f), args...) +# # @show typeof.(args) +# gargs = map(x -> get_cached(array_bank, x), args) +# # @show typeof.(gargs) +# # @timeit to "gf" gf = get_cached(array_bank, $f) +# gf = get_cached(array_bank, $f) +# # @show typeof($f) +# c = gf(gargs...) +# end +# end +# eval(q) + +# c = invoke(cuize, Tuple{typeof(f), typeof.(args)...}, (f, args...)) +# end + +function children(x::T, fs = fieldnames(T)) where T + map(f -> get_cached(array_bank, getproperty(x, f)), fs) +end + +mapchildren(x::T) where T = @eval $(Symbol(T.name))($(children(x))...) + +function get_cached(array_bank, x) + + x isa Broadcasted && return x + x isa Type && return x + x isa Function && return x + + haskey(__context__, x) && return __context__[x] + + __context__[x] = mapchildren(x) +end + +get_cached(array_bank, t::Tuple) = t From bc40e6d2c6bc9a5b49db1d0e6987e3f0cfb901bb Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 3 Jul 2019 16:09:22 +0530 Subject: [PATCH 11/32] clean interface --- src/compiler.jl | 125 ++++++++++-------------------------------------- 1 file changed, 26 insertions(+), 99 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 84faa505..8030b38d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -14,52 +14,12 @@ array_bank = IdDict{AbstractArray, AbstractArray}() # Hold all the results related to the op, but permanently __context__ = IdDict() -# function cache(__context__, f, args...) -# q = quote -# function cuize(::typeof($f), args...) -# if haskey(__context__, ($f, args...)) -# __context__[($f, args...)] -# else -# gargs = map(x -> get_cached(array_bank, x), args) -# c = $f(gargs...) -# __context__[($f, args...)] = c -# end -# end -# end -# eval(q) -# end - -# function get_cached end - -# function cuize(f, args...) -# q = quote -# function cuize(::typeof($f), args...) -# gargs = map(x -> get_cached(array_bank, x), args) -# c = $f(gargs...) -# end -# end -# eval(q) -# # cuize(f, args...) -# end - for f in (:+, :-, :*, :/) q = quote function cuize(::typeof($f), a::AbstractArray, b::AbstractArray) - # @show length(a) - # ga = get_cached(array_bank, a) - # gb = get_cached(array_bank, b) - - # c = $f(ga, gb) - # __context__[c] = c - # if haskey(__context__, ($f, a, b)) - # __context__[($f, a, b)] - # else - # @timeit to "cuize" begin ga = get_cached(array_bank, a) gb = get_cached(array_bank, b) c = $f(ga, gb) - # __context__[($f, a, b)] = c - # end end end eval(q) @@ -68,10 +28,6 @@ end function get_cached(array_bank, arr::AbstractArray) # CuArrays can come up when you have outputs/ movements before ops arr isa CuArray && return arr - - # Broadcasted objects are new everytime they're generated, ignore them - arr isa Broadcasted && return arr - arr isa TrackedArray && arr.data isa CuArray && return arr haskey(array_bank, arr) ? @@ -79,39 +35,12 @@ function get_cached(array_bank, arr::AbstractArray) cache(array_bank, arr) end -# get_cached(x::AbstractArray) = get_cached(array_bank, parent(x)) - -# get_cached(array_bank, arr::TrackedArray) = get_cached(array_bank, Tracker.data(arr)) -# get_cached(array_bank, arr::TrackedArray) = get_cached(Tracker.data(arr)) - -function cache(array_bank, arr::AbstractArray) +function cache(array_bank::IdDict{T,T}, arr::AbstractArray) where T <: AbstractArray array_bank[arr] = cu(arr) end -# using `Array` instead of `cpu` here works but -# causes tracking issues with TrackedArray -# cuize(::typeof(*), a, b) = cpu(cu(a) * cu(b)) -# cuize(::typeof(+), a, b) = cpu(cu(a) + cu(b)) -# cuize(::typeof(-), a, b) = cpu(cu(a) - cu(b)) -# cuize(::typeof(/), a, b) = cpu(cu(a) / cu(b)) cuize(::typeof(materialize), bc) = materialize(bc) -# function cuize(::typeof(getproperty), args...) -# # @show args -# getproperty(args...) -# end - -# function cuize(::typeof(getfield), args...) -# # @info "in getfield: $(typeof.(args))" -# getfield(args...) -# # try -# # getfield(args...) -# # catch e -# # @show typeof.(args) -# # throw() -# # end -# end - function cuize(::typeof(broadcasted), f, args...) gargs = map(x -> get_cached(array_bank, x), args) broadcasted(f, gargs...) @@ -135,28 +64,22 @@ end return finish(pr) end - - - ################################################################### -# function cuize(f, args...) -# T = Tuple{typeof(f), typeof.(args)...} -# q = quote -# function cuize(::typeof($f), args...) -# # @show typeof.(args) -# gargs = map(x -> get_cached(array_bank, x), args) -# # @show typeof.(gargs) -# # @timeit to "gf" gf = get_cached(array_bank, $f) -# gf = get_cached(array_bank, $f) -# # @show typeof($f) -# c = gf(gargs...) -# end -# end -# eval(q) - -# c = invoke(cuize, Tuple{typeof(f), typeof.(args)...}, (f, args...)) -# end +# Makes things work, but breaks continuity +# Gets called after every line of IR, make it stop +function cuize(f::T, arg1, args...) where T + q = quote + function cuize(::typeof($f), arg1, args...) + garg1 = get_cached(arg1) + gargs = map(get_cached, args) + gf = get_cached($f) + c = gf(garg1, gargs...) + end + end + eval(q) + c = invoke(cuize, Tuple{typeof(f), typeof(arg1), typeof.(args)...}, (f, arg1, args...)) +end function children(x::T, fs = fieldnames(T)) where T map(f -> get_cached(array_bank, getproperty(x, f)), fs) @@ -164,15 +87,19 @@ end mapchildren(x::T) where T = @eval $(Symbol(T.name))($(children(x))...) -function get_cached(array_bank, x) - - x isa Broadcasted && return x - x isa Type && return x - x isa Function && return x - +function get_cached(__context__, x::T) where T haskey(__context__, x) && return __context__[x] + x isa CuArray && return x + x isa TrackedArray && x.data isa CuArray && return x __context__[x] = mapchildren(x) end -get_cached(array_bank, t::Tuple) = t +get_cached(array_bank, t::Union{Type, Function, Broadcasted, T}) where {T <: Real} = t +get_cached(array_bank, t::Union{Tuple,NamedTuple}) = map(get_cached, t) + +function get_cached(x::T) where T + T <: AbstractArray && return get_cached(array_bank, x) + isstructtype(T) && return get_cached(__context__, x) + get_cached(array_bank, x) +end From 7ecfbc9af73ceb5db90b882ae8d7e2b941d0226e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 6 Jul 2019 16:37:30 +0530 Subject: [PATCH 12/32] dont cache objects --- src/compiler.jl | 115 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 87 insertions(+), 28 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8030b38d..1a83fe8b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -25,7 +25,7 @@ for f in (:+, :-, :*, :/) eval(q) end -function get_cached(array_bank, arr::AbstractArray) +function get_cached(array_bank::IdDict{T,T}, arr::AbstractArray) where T <: AbstractArray # CuArrays can come up when you have outputs/ movements before ops arr isa CuArray && return arr arr isa TrackedArray && arr.data isa CuArray && return arr @@ -35,11 +35,35 @@ function get_cached(array_bank, arr::AbstractArray) cache(array_bank, arr) end -function cache(array_bank::IdDict{T,T}, arr::AbstractArray) where T <: AbstractArray +function cache(array_bank::IdDict{T,T}, arr::AbstractArray{<:Real}) where T <: AbstractArray + array_bank[arr] = cu(arr) end -cuize(::typeof(materialize), bc) = materialize(bc) +cache(array_bank::IdDict{T,T}, arr::AbstractArray) where T <: AbstractArray = map(get_cached, arr) +# function cache(array_bank::IdDict, s::K) where K <: AbstractSet} +# t = K() +# for p in s +# push!(t, get_cached(p)) +# end +# t +# end + +function cuable(st) + args = st.expr.args + + # @show typeof.(args) + # @show tellmethetype.(args) + flag = true + for x in args + if x isa GlobalRef && x.mod == NNlib && x.name == :DenseConvDims + flag = false + break + end + end + # map(x -> x isa GlobalRef ? @show x.name : x, args) + flag +end function cuize(::typeof(broadcasted), f, args...) gargs = map(x -> get_cached(array_bank, x), args) @@ -59,47 +83,82 @@ end for (v,st) in pr isexpr(st.expr, :call) || continue ex = st.expr - pr[v] = Expr(:call, GlobalRef(Main, :cuize), st.expr.args...) + # if cuable(st) + pr[v] = Expr(:call, GlobalRef(Main, :cuize), st.expr.args...) + # else + # @show "trying to write wrapper" + # pr[v] = Expr(:call, GlobalRef(Main, :fcuize), st.expr.args...) + # @show "written wrapper" + # end end return finish(pr) end +cuize(::typeof(setindex!), ::Tuple, args...) = tuple(args...) + ################################################################### -# Makes things work, but breaks continuity +# Makes things work, but breaks continuity and recursion # Gets called after every line of IR, make it stop -function cuize(f::T, arg1, args...) where T - q = quote - function cuize(::typeof($f), arg1, args...) - garg1 = get_cached(arg1) - gargs = map(get_cached, args) - gf = get_cached($f) - c = gf(garg1, gargs...) - end - end - eval(q) - c = invoke(cuize, Tuple{typeof(f), typeof(arg1), typeof.(args)...}, (f, arg1, args...)) -end +# Basically a horrible hack around getting constructors to work +# function cuize(f, arg1, args...) +# # @show f +# # if Symbol(f) == :DenseConvDims + +# q = quote +# function cuize(::typeof($f), arg1, args...) +# garg1 = get_cached(arg1) +# gargs = map(get_cached, args) +# gf = get_cached($f) +# c = gf(garg1, gargs...) +# end +# end +# eval(q) +# # end + +# c = invoke(cuize, Tuple{typeof(f), typeof(arg1), typeof.(args)...}, (f, arg1, args...)) +# end function children(x::T, fs = fieldnames(T)) where T - map(f -> get_cached(array_bank, getproperty(x, f)), fs) + map(f -> get_cached(getproperty(x, f)), fs) # get_cached -> get_cached(array_bank, getproperty...) end +children(x::Tuple) = map(get_cached, x) +# function children(s::T) where T <:AbstractSet +# t = T() +# for p in s +# push!(t, get_cached(p)) +# end +# t +# end + + mapchildren(x::T) where T = @eval $(Symbol(T.name))($(children(x))...) +# mapchildren(x::T) where T<:AbstractSet = children(x) -function get_cached(__context__, x::T) where T - haskey(__context__, x) && return __context__[x] - x isa CuArray && return x - x isa TrackedArray && x.data isa CuArray && return x +# function get_cached(__context__, x::T) where T +# haskey(__context__, x) && return __context__[x] +# x isa CuArray && return x +# x isa TrackedArray && x.data isa CuArray && return x - __context__[x] = mapchildren(x) -end +# __context__[x] = mapchildren(x) +# end -get_cached(array_bank, t::Union{Type, Function, Broadcasted, T}) where {T <: Real} = t -get_cached(array_bank, t::Union{Tuple,NamedTuple}) = map(get_cached, t) +get_cached(array_bank, t::Union{Type, Function, Broadcasted, Symbol, T}) where {T <: Real} = t +get_cached(array_bank, t::Union{Tuple, NamedTuple}) = map(get_cached, t) function get_cached(x::T) where T - T <: AbstractArray && return get_cached(array_bank, x) - isstructtype(T) && return get_cached(__context__, x) + # T <: AbstractArray && return get_cached(array_bank, x) + # isstructtype(T) && return get_cached(__context__, x) get_cached(array_bank, x) end + +function noop_pass(f, args...) + @eval cuize(::typeof($f), args...) = $f(args...) +end + +noop_pass.((getproperty, materialize, )) + +@generated function tellmethetype(x) + x +end From b077a9889a1be38541956445f740574cc3a57fba Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 16 Jul 2019 12:56:04 +0530 Subject: [PATCH 13/32] use cached nametuples --- src/compiler.jl | 269 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 208 insertions(+), 61 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 1a83fe8b..54dde2ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,5 +1,5 @@ using IRTools: isexpr, IR, @dynamo -using IRTools: meta, Pipe, finish +using IRTools: meta, Pipe, finish, Variable import Base.Broadcast.broadcasted import Base.Broadcast.materialize @@ -10,9 +10,10 @@ import Base.Broadcast.Broadcasted # array_bank = WeakKeyDict{AbstractArray, AbstractArray}() # array_bank = WeakKeyDict() array_bank = IdDict{AbstractArray, AbstractArray}() +# array_bank = IdDict() -# Hold all the results related to the op, but permanently -__context__ = IdDict() +# Hold all the objects related to the op +obs = IdDict() for f in (:+, :-, :*, :/) q = quote @@ -41,7 +42,7 @@ function cache(array_bank::IdDict{T,T}, arr::AbstractArray{<:Real}) where T <: A end cache(array_bank::IdDict{T,T}, arr::AbstractArray) where T <: AbstractArray = map(get_cached, arr) -# function cache(array_bank::IdDict, s::K) where K <: AbstractSet} +# function cache(array_bank::IdDict, s::K) where K <: AbstractSet # t = K() # for p in s # push!(t, get_cached(p)) @@ -49,78 +50,109 @@ cache(array_bank::IdDict{T,T}, arr::AbstractArray) where T <: AbstractArray = ma # t # end -function cuable(st) - args = st.expr.args - +function cuize(::typeof(broadcasted), f, args...) + # @show f # @show typeof.(args) - # @show tellmethetype.(args) - flag = true - for x in args - if x isa GlobalRef && x.mod == NNlib && x.name == :DenseConvDims - flag = false - break - end - end - # map(x -> x isa GlobalRef ? @show x.name : x, args) - flag + gargs = map(x -> get_cached(array_bank, x), args) + Main.broadcasted(f, gargs...) end -function cuize(::typeof(broadcasted), f, args...) - gargs = map(x -> get_cached(array_bank, x), args) - broadcasted(f, gargs...) +function cuize(::typeof(getproperty), o, s::Symbol) + getproperty(o, s) |> get_cached end +# function cuize(::typeof(getproperty), o, s::Symbol) +# getproperty(get_cached(o), s) +# end + function cuize(::typeof(broadcast), f, args...) gargs = map(x -> get_cached(array_bank, x), args) broadcast(f, gargs...) end +# function cuize(::typeof(reshape), arr::AbstractArray, args...; kwargs...) +# @show "in reshape" +# reshape(get_cached(arr), args...; kwargs...) +# end + +# cuize(::typeof(reshape), arr::Base.OneTo, args...; kwargs...) = reshape(get_cached(collect(arr)), args...; kwargs...) + +# function h!(ir) +# args = IRTools.arguments(ir) + +# for a in args +# pushfirst!(ir, :(get_cached($a))) +# # IRTools.deletearg!(ir, a) +# # ir[a] = :(get_cached($a)) +# end +# ir +# end + @dynamo function cuize(meta...) meta == nothing && return ir = IR(meta...) + # @show ir ir == nothing && return + # args = IRTools.arguments(ir) + # ir = IRTools.postwalk(ir) do x + # x in args && return Expr(:call, GlobalRef(Main, :get_cached), x) + # return x + # end + pr = Pipe(ir) for (v,st) in pr isexpr(st.expr, :call) || continue ex = st.expr - # if cuable(st) - pr[v] = Expr(:call, GlobalRef(Main, :cuize), st.expr.args...) - # else - # @show "trying to write wrapper" - # pr[v] = Expr(:call, GlobalRef(Main, :fcuize), st.expr.args...) - # @show "written wrapper" + + # ex = IRTools.postwalk(ex) do x + # x in args && reutrn :(get_cached(x)) + # return x # end + # @show ex + # ex isa Nothing && continue + # ex = IRTools.postwalk(ex) do x + + # i = findall(y -> y == x, ex.args) + # if length(i) > 0 + # i = i[1] + # temp = Expr(:call, GlobalRef(Main, :get_cached), ex.args[i]) + # ex.args[i] = GlobalRef(Main, :temp) + # end + # # temp = Expr(:call, GlobalRef(Main, :get_cached), arg) + # # GlobalRef(Main, :temp) + # end + # ex = IRTools.postwalk(ex) do x + # x isa GlobalRef && x in ex.args && return IRTools.xcall(Main, :get_cached, x) + # x + # end + + + pr[v] = Expr(:call, GlobalRef(Main, :cuize), ex.args...) + end return finish(pr) end -cuize(::typeof(setindex!), ::Tuple, args...) = tuple(args...) -################################################################### +# function cuize(::typeof(reshape), args...) +# reshape(map(get_cached, args)...) +# end -# Makes things work, but breaks continuity and recursion -# Gets called after every line of IR, make it stop -# Basically a horrible hack around getting constructors to work -# function cuize(f, arg1, args...) -# # @show f -# # if Symbol(f) == :DenseConvDims - -# q = quote -# function cuize(::typeof($f), arg1, args...) -# garg1 = get_cached(arg1) -# gargs = map(get_cached, args) -# gf = get_cached($f) -# c = gf(garg1, gargs...) -# end -# end -# eval(q) -# # end +cuize(::typeof(setindex!), ::Tuple, args...) = tuple(args...) -# c = invoke(cuize, Tuple{typeof(f), typeof(arg1), typeof.(args)...}, (f, arg1, args...)) -# end +################################################################### function children(x::T, fs = fieldnames(T)) where T - map(f -> get_cached(getproperty(x, f)), fs) # get_cached -> get_cached(array_bank, getproperty...) + # map(f -> get_cached(getproperty(x, f)), fs) # get_cached -> get_cached(array_bank, getproperty...) + # q = quote + # function cuize(c::$(nameof(T)), args...) + # gargs = map(get_cached, args) + # g_c = $(nameof(T))(get_cached(c)...) + # g_c(gargs...) + # end + # end + # eval(q) + (; zip(fs, map(f -> get_cached(getproperty(x, f)), fs))...) end children(x::Tuple) = map(get_cached, x) @@ -133,32 +165,147 @@ children(x::Tuple) = map(get_cached, x) # end -mapchildren(x::T) where T = @eval $(Symbol(T.name))($(children(x))...) +# mapchildren(x::T) where T = @eval $(Symbol(T.name))($(children(x))...) +mapchildren(x::T) where T = children(x) # mapchildren(x::T) where T<:AbstractSet = children(x) -# function get_cached(__context__, x::T) where T -# haskey(__context__, x) && return __context__[x] -# x isa CuArray && return x -# x isa TrackedArray && x.data isa CuArray && return x +function get_cached(obs, x::T) where T + haskey(obs, x) && return obs[x] + x isa CuArray && return x + x isa TrackedArray && x.data isa CuArray && return x -# __context__[x] = mapchildren(x) -# end + obs[x] = mapchildren(x) +end -get_cached(array_bank, t::Union{Type, Function, Broadcasted, Symbol, T}) where {T <: Real} = t +get_cached(array_bank, t::Union{Type, Function, Broadcasted, Symbol, Module, Nothing, Missing, Ptr, T}) where {T <: Real} = t get_cached(array_bank, t::Union{Tuple, NamedTuple}) = map(get_cached, t) function get_cached(x::T) where T - # T <: AbstractArray && return get_cached(array_bank, x) - # isstructtype(T) && return get_cached(__context__, x) + T <: AbstractArray && return get_cached(array_bank, x) + isstructtype(T) && return x # get_cached(obs, x) get_cached(array_bank, x) end +""" + Disable `cuize` for a function +""" function noop_pass(f, args...) @eval cuize(::typeof($f), args...) = $f(args...) end -noop_pass.((getproperty, materialize, )) +noop_pass.((materialize, )) + +# This allows capturing calls that have at least one argument +# Use case handled is that a lot of calls to `Main.getfield(Symbol("#xx#xx"))` +# like cases also get covered with just `args...` which would be nice to avoid +# CANT DO THIS - OVERLOADS CUIZE AND BREAKS RECURSION +# function cuize(f::Function, xs, args...) +# gxs = get_cached(xs) +# f(gxs, map(get_cached, args)...) +# end +# function cuize(o, xs, args...) +# gxs = get_cached(xs) +# o(gxs, map(get_cached, args)...) +# end + + +cuize(::typeof(get_cached), args...) = get_cached(args...) + + +################################################### + +# Adding get_cached calls in dynamo - +# Before making Pipe - core dumped; julia crashed +# Inside Pipe - Bad compiler errors argextype ones +# Need to find a way to catch arguments +# + +# cuize(a) = a() + -@generated function tellmethetype(x) - x +# function set_cuize(f::T) where T +# # isstructtype(typeof(f)) +# @show T.name +# q = quote +# function cuize(::$(T.name), args...) +# gargs = map(get_cached, args) +# cuize() do +# ($f::$(T.name))(gargs...) +# end +# end +# end +# @show q +# # eval(q) +# end + +# function modifyex(ex) +# @assert IRTools.isexpr(ex, :call) +# args = ex.args +# fs = Symbol[args[1]] +# for a in args[2:end] +# if a isa Expr +# @assert IRTools.isexpr(a, :call) +# push!(fs, a.args[1]) +# end +# end +# mex = IRTools.postwalk(ex) do x +# x isa Expr && return x +# if x in fs +# return x +# else +# return :(get_cached($x)) +# end +# end +# mex +# end + +# macro cuize(ex) +# mex = modifyex(ex) +# return :(cuize() do +# $mex +# end) +# end + + + +function makechildren(T::Type, nt::NamedTuple) + eval(nameof(T))(nt...) end +# function cuize(f::Function, xs, args...) +# # @show f +# gxs = get_cached(xs) +# gargs = map(get_cached, args) +# f(gxs, gargs...) +# end + + +# Functions called inside `cuize` aren't executed as part of the context +# So any assumptions made inside the context (`getproperty`, for eg) will +# not Hold +# Thus we need the actual objects when trying to call the objects, as opposed to +# continuing inside the context where we can pick fields up from a NamedTuple +# Without this limitation, we can avoid caching the structs themselves +function cuize(::typeof(invoke), f::T, types, args...) where T + @timeit to "bad stuff" gf = f isa Function ? f : makechildren(T, get_cached(obs, f)) + invoke(gf, types, map(get_cached, args)...) + end + + + +# Same problem as recursing; works for basic stuff +# if objects are cached as objects +# function cuize(f::Function, xs, args...) +# # @show f +# gxs = get_cached(xs) +# gargs = map(get_cached, args) +# f(gxs, gargs...) +# end + + + + +# function get_cached(__context__::Function, xs) +# __context__() do +# xs +# end +# end \ No newline at end of file From f3a484f117db5847cc070b56337997a0eed3fa25 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 16 Jul 2019 17:02:20 +0530 Subject: [PATCH 14/32] cleanup --- src/compiler.jl | 168 ++---------------------------------------------- 1 file changed, 4 insertions(+), 164 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 54dde2ca..6c38caa7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -10,7 +10,6 @@ import Base.Broadcast.Broadcasted # array_bank = WeakKeyDict{AbstractArray, AbstractArray}() # array_bank = WeakKeyDict() array_bank = IdDict{AbstractArray, AbstractArray}() -# array_bank = IdDict() # Hold all the objects related to the op obs = IdDict() @@ -51,8 +50,6 @@ cache(array_bank::IdDict{T,T}, arr::AbstractArray) where T <: AbstractArray = ma # end function cuize(::typeof(broadcasted), f, args...) - # @show f - # @show typeof.(args) gargs = map(x -> get_cached(array_bank, x), args) Main.broadcasted(f, gargs...) end @@ -66,92 +63,32 @@ end # end function cuize(::typeof(broadcast), f, args...) + @show f gargs = map(x -> get_cached(array_bank, x), args) broadcast(f, gargs...) end -# function cuize(::typeof(reshape), arr::AbstractArray, args...; kwargs...) -# @show "in reshape" -# reshape(get_cached(arr), args...; kwargs...) -# end - -# cuize(::typeof(reshape), arr::Base.OneTo, args...; kwargs...) = reshape(get_cached(collect(arr)), args...; kwargs...) - -# function h!(ir) -# args = IRTools.arguments(ir) - -# for a in args -# pushfirst!(ir, :(get_cached($a))) -# # IRTools.deletearg!(ir, a) -# # ir[a] = :(get_cached($a)) -# end -# ir -# end - @dynamo function cuize(meta...) meta == nothing && return ir = IR(meta...) - # @show ir ir == nothing && return - # args = IRTools.arguments(ir) - # ir = IRTools.postwalk(ir) do x - # x in args && return Expr(:call, GlobalRef(Main, :get_cached), x) - # return x - # end pr = Pipe(ir) for (v,st) in pr isexpr(st.expr, :call) || continue ex = st.expr - # ex = IRTools.postwalk(ex) do x - # x in args && reutrn :(get_cached(x)) - # return x - # end - # @show ex - # ex isa Nothing && continue - # ex = IRTools.postwalk(ex) do x - - # i = findall(y -> y == x, ex.args) - # if length(i) > 0 - # i = i[1] - # temp = Expr(:call, GlobalRef(Main, :get_cached), ex.args[i]) - # ex.args[i] = GlobalRef(Main, :temp) - # end - # # temp = Expr(:call, GlobalRef(Main, :get_cached), arg) - # # GlobalRef(Main, :temp) - # end - # ex = IRTools.postwalk(ex) do x - # x isa GlobalRef && x in ex.args && return IRTools.xcall(Main, :get_cached, x) - # x - # end - - pr[v] = Expr(:call, GlobalRef(Main, :cuize), ex.args...) end return finish(pr) end - -# function cuize(::typeof(reshape), args...) -# reshape(map(get_cached, args)...) -# end - cuize(::typeof(setindex!), ::Tuple, args...) = tuple(args...) ################################################################### function children(x::T, fs = fieldnames(T)) where T - # map(f -> get_cached(getproperty(x, f)), fs) # get_cached -> get_cached(array_bank, getproperty...) - # q = quote - # function cuize(c::$(nameof(T)), args...) - # gargs = map(get_cached, args) - # g_c = $(nameof(T))(get_cached(c)...) - # g_c(gargs...) - # end - # end - # eval(q) (; zip(fs, map(f -> get_cached(getproperty(x, f)), fs))...) end @@ -195,88 +132,11 @@ end noop_pass.((materialize, )) -# This allows capturing calls that have at least one argument -# Use case handled is that a lot of calls to `Main.getfield(Symbol("#xx#xx"))` -# like cases also get covered with just `args...` which would be nice to avoid -# CANT DO THIS - OVERLOADS CUIZE AND BREAKS RECURSION -# function cuize(f::Function, xs, args...) -# gxs = get_cached(xs) -# f(gxs, map(get_cached, args)...) -# end -# function cuize(o, xs, args...) -# gxs = get_cached(xs) -# o(gxs, map(get_cached, args)...) -# end - - cuize(::typeof(get_cached), args...) = get_cached(args...) - -################################################### - -# Adding get_cached calls in dynamo - -# Before making Pipe - core dumped; julia crashed -# Inside Pipe - Bad compiler errors argextype ones -# Need to find a way to catch arguments -# - -# cuize(a) = a() - - -# function set_cuize(f::T) where T -# # isstructtype(typeof(f)) -# @show T.name -# q = quote -# function cuize(::$(T.name), args...) -# gargs = map(get_cached, args) -# cuize() do -# ($f::$(T.name))(gargs...) -# end -# end -# end -# @show q -# # eval(q) -# end - -# function modifyex(ex) -# @assert IRTools.isexpr(ex, :call) -# args = ex.args -# fs = Symbol[args[1]] -# for a in args[2:end] -# if a isa Expr -# @assert IRTools.isexpr(a, :call) -# push!(fs, a.args[1]) -# end -# end -# mex = IRTools.postwalk(ex) do x -# x isa Expr && return x -# if x in fs -# return x -# else -# return :(get_cached($x)) -# end -# end -# mex -# end - -# macro cuize(ex) -# mex = modifyex(ex) -# return :(cuize() do -# $mex -# end) -# end - - - function makechildren(T::Type, nt::NamedTuple) eval(nameof(T))(nt...) end -# function cuize(f::Function, xs, args...) -# # @show f -# gxs = get_cached(xs) -# gargs = map(get_cached, args) -# f(gxs, gargs...) -# end # Functions called inside `cuize` aren't executed as part of the context @@ -286,26 +146,6 @@ end # continuing inside the context where we can pick fields up from a NamedTuple # Without this limitation, we can avoid caching the structs themselves function cuize(::typeof(invoke), f::T, types, args...) where T - @timeit to "bad stuff" gf = f isa Function ? f : makechildren(T, get_cached(obs, f)) - invoke(gf, types, map(get_cached, args)...) - end - - - -# Same problem as recursing; works for basic stuff -# if objects are cached as objects -# function cuize(f::Function, xs, args...) -# # @show f -# gxs = get_cached(xs) -# gargs = map(get_cached, args) -# f(gxs, gargs...) -# end - - - - -# function get_cached(__context__::Function, xs) -# __context__() do -# xs -# end -# end \ No newline at end of file + gf = f isa Function ? f : makechildren(T, get_cached(obs, f)) + invoke(gf, types, map(get_cached, args)...) +end \ No newline at end of file From fb5b05899d5ef662bddef75b5e6a39932ef10a61 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 16 Jul 2019 10:14:58 -0400 Subject: [PATCH 15/32] compiler.jl -> context.jl --- src/CuArrays.jl | 2 ++ src/{compiler.jl => context.jl} | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) rename src/{compiler.jl => context.jl} (99%) diff --git a/src/CuArrays.jl b/src/CuArrays.jl index dcb27410..d249689a 100644 --- a/src/CuArrays.jl +++ b/src/CuArrays.jl @@ -51,6 +51,8 @@ libcudnn !== nothing && include("dnn/CUDNN.jl") include("nnlib.jl") +include("context.jl") + include("deprecated.jl") function __init__() diff --git a/src/compiler.jl b/src/context.jl similarity index 99% rename from src/compiler.jl rename to src/context.jl index 6c38caa7..e2d2b4a3 100644 --- a/src/compiler.jl +++ b/src/context.jl @@ -148,4 +148,4 @@ end function cuize(::typeof(invoke), f::T, types, args...) where T gf = f isa Function ? f : makechildren(T, get_cached(obs, f)) invoke(gf, types, map(get_cached, args)...) -end \ No newline at end of file +end From d44a86d8313017feff1ce8d70d23b8cbfa7b84c5 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 16 Jul 2019 10:19:30 -0400 Subject: [PATCH 16/32] add IRTools dependency --- Project.toml | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 3585a65f..fb9fedda 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3" CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde" CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" @@ -18,6 +19,15 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +[compat] +Adapt = "0.4" +CUDAapi = "0.5.3, 0.6, 1.0" +CUDAdrv = "3.0" +CUDAnative = "2.0" +GPUArrays = "0.7" +NNlib = "0.6" +julia = "1.0" + [extras] FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -25,12 +35,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "FFTW", "ForwardDiff"] - -[compat] -julia = "1.0" -CUDAnative = "2.0" -CUDAdrv = "3.0" -CUDAapi = "0.5.3, 0.6, 1.0" -NNlib = "0.6" -GPUArrays = "0.7" -Adapt = "0.4" From 8997b6c59287c19db05371dac8238943382612c7 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 16 Jul 2019 10:34:46 -0400 Subject: [PATCH 17/32] simple test --- src/CuArrays.jl | 2 +- src/context.jl | 35 ++++++++++++++++------------------- test/context.jl | 8 ++++++++ test/runtests.jl | 1 + 4 files changed, 26 insertions(+), 20 deletions(-) create mode 100644 test/context.jl diff --git a/src/CuArrays.jl b/src/CuArrays.jl index d249689a..382a2947 100644 --- a/src/CuArrays.jl +++ b/src/CuArrays.jl @@ -4,7 +4,7 @@ using CUDAdrv, CUDAnative using GPUArrays -export CuArray, CuVector, CuMatrix, CuVecOrMat, cu +export CuArray, CuVector, CuMatrix, CuVecOrMat, cu, cuda import LinearAlgebra diff --git a/src/context.jl b/src/context.jl index e2d2b4a3..1283f600 100644 --- a/src/context.jl +++ b/src/context.jl @@ -7,40 +7,34 @@ import Base.Broadcast.Broadcasted # Hold all the arrays related to the op -# array_bank = WeakKeyDict{AbstractArray, AbstractArray}() -# array_bank = WeakKeyDict() -array_bank = IdDict{AbstractArray, AbstractArray}() +# TODO: this should be a context +const array_bank = WeakKeyDict{AbstractArray, AbstractArray}() # Hold all the objects related to the op -obs = IdDict() +# obs = IdDict() for f in (:+, :-, :*, :/) - q = quote - function cuize(::typeof($f), a::AbstractArray, b::AbstractArray) - ga = get_cached(array_bank, a) - gb = get_cached(array_bank, b) - c = $f(ga, gb) - end - end - eval(q) + @eval function cuize(::typeof($f), a::AbstractArray, b::AbstractArray) + ga = get_cached(array_bank, a) + gb = get_cached(array_bank, b) + c = $f(ga, gb) + end end -function get_cached(array_bank::IdDict{T,T}, arr::AbstractArray) where T <: AbstractArray +function get_cached(array_bank, arr::AbstractArray) # CuArrays can come up when you have outputs/ movements before ops arr isa CuArray && return arr - arr isa TrackedArray && arr.data isa CuArray && return arr haskey(array_bank, arr) ? array_bank[arr] : cache(array_bank, arr) end -function cache(array_bank::IdDict{T,T}, arr::AbstractArray{<:Real}) where T <: AbstractArray - +function cache(array_bank, arr::AbstractArray{<:Real}) where T <: AbstractArray array_bank[arr] = cu(arr) end -cache(array_bank::IdDict{T,T}, arr::AbstractArray) where T <: AbstractArray = map(get_cached, arr) +cache(array_bank, arr::AbstractArray) = map(get_cached, arr) # function cache(array_bank::IdDict, s::K) where K <: AbstractSet # t = K() # for p in s @@ -78,7 +72,7 @@ end isexpr(st.expr, :call) || continue ex = st.expr - pr[v] = Expr(:call, GlobalRef(Main, :cuize), ex.args...) + pr[v] = Expr(:call, GlobalRef(CuArrays, :cuize), ex.args...) end return finish(pr) @@ -109,7 +103,6 @@ mapchildren(x::T) where T = children(x) function get_cached(obs, x::T) where T haskey(obs, x) && return obs[x] x isa CuArray && return x - x isa TrackedArray && x.data isa CuArray && return x obs[x] = mapchildren(x) end @@ -149,3 +142,7 @@ function cuize(::typeof(invoke), f::T, types, args...) where T gf = f isa Function ? f : makechildren(T, get_cached(obs, f)) invoke(gf, types, map(get_cached, args)...) end + +function cuda(f) + cuize(f) +end diff --git a/test/context.jl b/test/context.jl new file mode 100644 index 00000000..c9fa7fdf --- /dev/null +++ b/test/context.jl @@ -0,0 +1,8 @@ +using CuArrays, Test + +W = rand(5, 5) +b = rand(5) + +@test cuda() do + W*b +end ≈ W*b diff --git a/test/runtests.jl b/test/runtests.jl index f41a1bba..5f50e9f5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,6 +24,7 @@ include("sparse.jl") include("solver.jl") include("sparse_solver.jl") include("dnn.jl") +include("context.jl") CuArrays.pool_status() CuArrays.pool_timings() From 99425a1496369b0392b6e9ad79157f51ab630449 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 16 Jul 2019 11:15:18 -0400 Subject: [PATCH 18/32] restrict to arrays --- src/context.jl | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/src/context.jl b/src/context.jl index 1283f600..78d3cdad 100644 --- a/src/context.jl +++ b/src/context.jl @@ -8,41 +8,25 @@ import Base.Broadcast.Broadcasted # Hold all the arrays related to the op # TODO: this should be a context -const array_bank = WeakKeyDict{AbstractArray, AbstractArray}() +const array_bank = WeakKeyDict{Array,CuArray}() # Hold all the objects related to the op # obs = IdDict() for f in (:+, :-, :*, :/) - @eval function cuize(::typeof($f), a::AbstractArray, b::AbstractArray) + @eval function cuize(::typeof($f), a::Array, b::Array) ga = get_cached(array_bank, a) gb = get_cached(array_bank, b) c = $f(ga, gb) end end -function get_cached(array_bank, arr::AbstractArray) - # CuArrays can come up when you have outputs/ movements before ops - arr isa CuArray && return arr - +function get_cached(array_bank, arr::Array{T,N})::CuArray{T,N} where {T,N} haskey(array_bank, arr) ? array_bank[arr] : - cache(array_bank, arr) -end - -function cache(array_bank, arr::AbstractArray{<:Real}) where T <: AbstractArray - array_bank[arr] = cu(arr) + (array_bank[arr] = CuArray(arr)) end -cache(array_bank, arr::AbstractArray) = map(get_cached, arr) -# function cache(array_bank::IdDict, s::K) where K <: AbstractSet -# t = K() -# for p in s -# push!(t, get_cached(p)) -# end -# t -# end - function cuize(::typeof(broadcasted), f, args...) gargs = map(x -> get_cached(array_bank, x), args) Main.broadcasted(f, gargs...) @@ -111,7 +95,7 @@ get_cached(array_bank, t::Union{Type, Function, Broadcasted, Symbol, Module, Not get_cached(array_bank, t::Union{Tuple, NamedTuple}) = map(get_cached, t) function get_cached(x::T) where T - T <: AbstractArray && return get_cached(array_bank, x) + T <: Array && return get_cached(array_bank, x) isstructtype(T) && return x # get_cached(obs, x) get_cached(array_bank, x) end From 01010c5d6a10f829053ae79c0b070c6ea2ec0a2c Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 16 Jul 2019 11:22:59 -0400 Subject: [PATCH 19/32] array keys --- src/context.jl | 8 +++++++- test/context.jl | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/context.jl b/src/context.jl index 78d3cdad..c086e14f 100644 --- a/src/context.jl +++ b/src/context.jl @@ -10,6 +10,12 @@ import Base.Broadcast.Broadcasted # TODO: this should be a context const array_bank = WeakKeyDict{Array,CuArray}() +function cache(cx, x::CuArray{T,N})::Array{T,N} where {T,N} + cpu = Array{T,N}(undef, ntuple(_->0,N)) + cx[cpu] = x + return cpu +end + # Hold all the objects related to the op # obs = IdDict() @@ -17,7 +23,7 @@ for f in (:+, :-, :*, :/) @eval function cuize(::typeof($f), a::Array, b::Array) ga = get_cached(array_bank, a) gb = get_cached(array_bank, b) - c = $f(ga, gb) + cache(array_bank, $f(ga, gb)) end end diff --git a/test/context.jl b/test/context.jl index c9fa7fdf..5e37e1c0 100644 --- a/test/context.jl +++ b/test/context.jl @@ -5,4 +5,4 @@ b = rand(5) @test cuda() do W*b -end ≈ W*b +end isa Array From 65d347cbf451e1b0af86f10e688ed7355d8af5ee Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 16 Jul 2019 13:11:54 -0400 Subject: [PATCH 20/32] refill arrays --- src/context.jl | 27 +++++++++++++++++++++++---- test/context.jl | 4 +--- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/context.jl b/src/context.jl index c086e14f..cbf6c3a1 100644 --- a/src/context.jl +++ b/src/context.jl @@ -5,10 +5,19 @@ import Base.Broadcast.broadcasted import Base.Broadcast.materialize import Base.Broadcast.Broadcasted +function _resize!(a::Array, sz::NTuple{<:Any,Integer}) + ccall(:jl_array_grow_end, Cvoid, (Any, UInt), a, prod(sz)) + ptr = convert(Ptr{Csize_t},pointer_from_objref(a)) + for i = 1:length(sz) + unsafe_store!(ptr+8*(i+2), sz[i]) + end + return a +end -# Hold all the arrays related to the op -# TODO: this should be a context -const array_bank = WeakKeyDict{Array,CuArray}() +function refill!(a::Array, b::CuArray) + _resize!(a, size(b)) + copy!(a, b) +end function cache(cx, x::CuArray{T,N})::Array{T,N} where {T,N} cpu = Array{T,N}(undef, ntuple(_->0,N)) @@ -133,6 +142,16 @@ function cuize(::typeof(invoke), f::T, types, args...) where T invoke(gf, types, map(get_cached, args)...) end +# Hold all the arrays related to the op +# TODO: this should be a context +const array_bank = IdDict{Array,CuArray}() + function cuda(f) - cuize(f) + out = cuize(f) + for (x, cx) in array_bank + length(x) == length(cx) && continue + refill!(x, cx) + end + empty!(array_bank) + return out end diff --git a/test/context.jl b/test/context.jl index 5e37e1c0..24e9cccb 100644 --- a/test/context.jl +++ b/test/context.jl @@ -3,6 +3,4 @@ using CuArrays, Test W = rand(5, 5) b = rand(5) -@test cuda() do - W*b -end isa Array +@test cuda(() -> W*b) ≈ W*b From 815d49ba28c4e7c56b7b07675bf58ee7b4c0fbba Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 20 Aug 2019 00:19:14 +0530 Subject: [PATCH 21/32] use context/ rm iddict --- src/context.jl | 155 +++++++++++++++++++++++++++---------------------- 1 file changed, 84 insertions(+), 71 deletions(-) diff --git a/src/context.jl b/src/context.jl index cbf6c3a1..347c52d5 100644 --- a/src/context.jl +++ b/src/context.jl @@ -1,10 +1,33 @@ -using IRTools: isexpr, IR, @dynamo -using IRTools: meta, Pipe, finish, Variable +using IRTools: isexpr, IR, @dynamo, postwalk +using IRTools: meta, Pipe, finish, Variable, self +using MacroTools: @forward import Base.Broadcast.broadcasted import Base.Broadcast.materialize import Base.Broadcast.Broadcasted +struct IRCtx{T,K} + array_bank::IdDict{T,K} + + IRCtx() = new{Array, CuArray}(IdDict{Array,CuArray}()) +end + +# Display fns for debugging, remove before committing +function Base.summary(io::IO, c::IRCtx) + print(io, "IR Context for CUDA ") + summary(io, c.array_bank) +end + +function Base.show(io::IO, c::IRCtx{T,K}) where {T,K} + print(io, "IR Context for CUDA ") + display(c.array_bank) +end + +@forward IRCtx.array_bank Base.getindex, Base.iterate, + Base.setindex!, Base.empty!, + Base.length, + Base.first, Base.last, Base.haskey + function _resize!(a::Array, sz::NTuple{<:Any,Integer}) ccall(:jl_array_grow_end, Cvoid, (Any, UInt), a, prod(sz)) ptr = convert(Ptr{Csize_t},pointer_from_objref(a)) @@ -25,14 +48,12 @@ function cache(cx, x::CuArray{T,N})::Array{T,N} where {T,N} return cpu end -# Hold all the objects related to the op -# obs = IdDict() - for f in (:+, :-, :*, :/) - @eval function cuize(::typeof($f), a::Array, b::Array) + @eval function (c::IRCtx)(::typeof($f), a::AbstractArray, b::AbstractArray) ga = get_cached(array_bank, a) gb = get_cached(array_bank, b) - cache(array_bank, $f(ga, gb)) + # cache(array_bank, $f(ga, gb)) + $f(ga, gb) end end @@ -42,26 +63,38 @@ function get_cached(array_bank, arr::Array{T,N})::CuArray{T,N} where {T,N} (array_bank[arr] = CuArray(arr)) end -function cuize(::typeof(broadcasted), f, args...) +function (c::IRCtx)(::typeof(broadcasted), f, args...) gargs = map(x -> get_cached(array_bank, x), args) - Main.broadcasted(f, gargs...) + broadcasted(f, gargs...) end -function cuize(::typeof(getproperty), o, s::Symbol) +function (c::IRCtx)(::typeof(getproperty), o, s::Symbol) getproperty(o, s) |> get_cached end -# function cuize(::typeof(getproperty), o, s::Symbol) -# getproperty(get_cached(o), s) -# end - -function cuize(::typeof(broadcast), f, args...) - @show f +function (c::IRCtx)(::typeof(broadcast), f, args...) gargs = map(x -> get_cached(array_bank, x), args) broadcast(f, gargs...) end -@dynamo function cuize(meta...) +function (c::IRCtx)(::typeof(getfield), o, s::Symbol) + getfield(o, s) |> get_cached +end + +function wrap_cuize(f) + @eval function (c::IRCtx)(::typeof($f), args...) + gargs = map(get_cached, args) + $f(gargs...) # use `cache` + end +end + +wrap_cuize.((sum, similar, )) + +function (c::IRCtx)(::typeof(reshape), arr, args...) + reshape(get_cached(arr), args...) +end + +@dynamo function (c::IRCtx)(meta...) meta == nothing && return ir = IR(meta...) ir == nothing && return @@ -71,44 +104,17 @@ end isexpr(st.expr, :call) || continue ex = st.expr - pr[v] = Expr(:call, GlobalRef(CuArrays, :cuize), ex.args...) + pr[v] = Expr(:call, self, ex.args...) end return finish(pr) end -cuize(::typeof(setindex!), ::Tuple, args...) = tuple(args...) - -################################################################### - -function children(x::T, fs = fieldnames(T)) where T - (; zip(fs, map(f -> get_cached(getproperty(x, f)), fs))...) -end - -children(x::Tuple) = map(get_cached, x) -# function children(s::T) where T <:AbstractSet -# t = T() -# for p in s -# push!(t, get_cached(p)) -# end -# t -# end - - -# mapchildren(x::T) where T = @eval $(Symbol(T.name))($(children(x))...) -mapchildren(x::T) where T = children(x) -# mapchildren(x::T) where T<:AbstractSet = children(x) - -function get_cached(obs, x::T) where T - haskey(obs, x) && return obs[x] - x isa CuArray && return x - - obs[x] = mapchildren(x) -end - -get_cached(array_bank, t::Union{Type, Function, Broadcasted, Symbol, Module, Nothing, Missing, Ptr, T}) where {T <: Real} = t +get_cached(array_bank, t::Union{Type, UnitRange, Function, Broadcasted, Symbol, Module, Nothing, Missing, Ptr, CuPtr, T}) where {T <: Real} = t get_cached(array_bank, t::Union{Tuple, NamedTuple}) = map(get_cached, t) +get_cached(array_bank, x::CuArray) = x + function get_cached(x::T) where T T <: Array && return get_cached(array_bank, x) isstructtype(T) && return x # get_cached(obs, x) @@ -116,42 +122,49 @@ function get_cached(x::T) where T end """ - Disable `cuize` for a function + Disable `IRCtx` for a function """ -function noop_pass(f, args...) - @eval cuize(::typeof($f), args...) = $f(args...) +function noop_pass(f) + @eval (c::IRCtx)(::typeof($f), args...) = $f(args...) end -noop_pass.((materialize, )) +noop_pass.((materialize, get_cached, NNlib.check_spdf, + )) -cuize(::typeof(get_cached), args...) = get_cached(args...) - -function makechildren(T::Type, nt::NamedTuple) - eval(nameof(T))(nt...) -end - - -# Functions called inside `cuize` aren't executed as part of the context -# So any assumptions made inside the context (`getproperty`, for eg) will -# not Hold -# Thus we need the actual objects when trying to call the objects, as opposed to -# continuing inside the context where we can pick fields up from a NamedTuple -# Without this limitation, we can avoid caching the structs themselves -function cuize(::typeof(invoke), f::T, types, args...) where T - gf = f isa Function ? f : makechildren(T, get_cached(obs, f)) - invoke(gf, types, map(get_cached, args)...) +for f in names(NNlib) + getfield(NNlib, f) isa Function || continue + @eval function (c::IRCtx)(::typeof($f), args...) + gargs = map(get_cached, args) + # cache(array_bank, $f(gargs...)) + $f(gargs...) + end end # Hold all the arrays related to the op # TODO: this should be a context -const array_bank = IdDict{Array,CuArray}() +# BitArray and friends would like an AbstractArray construct +# const array_bank = IdDict{Array,CuArray}() +const array_bank = IRCtx() + +""" + Creates a `cuda` context within which we travel + through the entire callstack to find matrix/vector + operations and try to offload them to a GPU. + + Example: + ``` + cuda() do + # do something + end + ``` +""" function cuda(f) - out = cuize(f) + out = array_bank(f) for (x, cx) in array_bank length(x) == length(cx) && continue refill!(x, cx) end - empty!(array_bank) + # empty!(array_bank) return out end From 7a4934ea7960619bcb095315a47539b808d0da7c Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Wed, 28 Aug 2019 15:11:01 +0100 Subject: [PATCH 22/32] simplify CUDACtx --- src/context.jl | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/context.jl b/src/context.jl index 347c52d5..4a7f0f85 100644 --- a/src/context.jl +++ b/src/context.jl @@ -6,24 +6,25 @@ import Base.Broadcast.broadcasted import Base.Broadcast.materialize import Base.Broadcast.Broadcasted -struct IRCtx{T,K} - array_bank::IdDict{T,K} - - IRCtx() = new{Array, CuArray}(IdDict{Array,CuArray}()) +# TODO use a WeakKeyDict +struct CUDACtx + array_bank::IdDict{Array,CuArray} end +CUDACtx() = CUDACtx(IdDict{Array,CuArray}()) + # Display fns for debugging, remove before committing -function Base.summary(io::IO, c::IRCtx) +function Base.summary(io::IO, c::CUDACtx) print(io, "IR Context for CUDA ") summary(io, c.array_bank) end -function Base.show(io::IO, c::IRCtx{T,K}) where {T,K} +function Base.show(io::IO, c::CUDACtx) print(io, "IR Context for CUDA ") display(c.array_bank) end -@forward IRCtx.array_bank Base.getindex, Base.iterate, +@forward CUDACtx.array_bank Base.getindex, Base.iterate, Base.setindex!, Base.empty!, Base.length, Base.first, Base.last, Base.haskey @@ -49,7 +50,7 @@ function cache(cx, x::CuArray{T,N})::Array{T,N} where {T,N} end for f in (:+, :-, :*, :/) - @eval function (c::IRCtx)(::typeof($f), a::AbstractArray, b::AbstractArray) + @eval function (c::CUDACtx)(::typeof($f), a::AbstractArray, b::AbstractArray) ga = get_cached(array_bank, a) gb = get_cached(array_bank, b) # cache(array_bank, $f(ga, gb)) @@ -63,26 +64,26 @@ function get_cached(array_bank, arr::Array{T,N})::CuArray{T,N} where {T,N} (array_bank[arr] = CuArray(arr)) end -function (c::IRCtx)(::typeof(broadcasted), f, args...) +function (c::CUDACtx)(::typeof(broadcasted), f, args...) gargs = map(x -> get_cached(array_bank, x), args) broadcasted(f, gargs...) end -function (c::IRCtx)(::typeof(getproperty), o, s::Symbol) +function (c::CUDACtx)(::typeof(getproperty), o, s::Symbol) getproperty(o, s) |> get_cached end -function (c::IRCtx)(::typeof(broadcast), f, args...) +function (c::CUDACtx)(::typeof(broadcast), f, args...) gargs = map(x -> get_cached(array_bank, x), args) broadcast(f, gargs...) end -function (c::IRCtx)(::typeof(getfield), o, s::Symbol) +function (c::CUDACtx)(::typeof(getfield), o, s::Symbol) getfield(o, s) |> get_cached end function wrap_cuize(f) - @eval function (c::IRCtx)(::typeof($f), args...) + @eval function (c::CUDACtx)(::typeof($f), args...) gargs = map(get_cached, args) $f(gargs...) # use `cache` end @@ -90,11 +91,11 @@ end wrap_cuize.((sum, similar, )) -function (c::IRCtx)(::typeof(reshape), arr, args...) +function (c::CUDACtx)(::typeof(reshape), arr, args...) reshape(get_cached(arr), args...) end -@dynamo function (c::IRCtx)(meta...) +@dynamo function (c::CUDACtx)(meta...) meta == nothing && return ir = IR(meta...) ir == nothing && return @@ -122,10 +123,10 @@ function get_cached(x::T) where T end """ - Disable `IRCtx` for a function + Disable `CUDACtx` for a function """ function noop_pass(f) - @eval (c::IRCtx)(::typeof($f), args...) = $f(args...) + @eval (c::CUDACtx)(::typeof($f), args...) = $f(args...) end noop_pass.((materialize, get_cached, NNlib.check_spdf, @@ -133,7 +134,7 @@ noop_pass.((materialize, get_cached, NNlib.check_spdf, for f in names(NNlib) getfield(NNlib, f) isa Function || continue - @eval function (c::IRCtx)(::typeof($f), args...) + @eval function (c::CUDACtx)(::typeof($f), args...) gargs = map(get_cached, args) # cache(array_bank, $f(gargs...)) $f(gargs...) @@ -145,7 +146,7 @@ end # BitArray and friends would like an AbstractArray construct # const array_bank = IdDict{Array,CuArray}() -const array_bank = IRCtx() +const array_bank = CUDACtx() """ Creates a `cuda` context within which we travel From 50eaec920a596b81c29d1fab2910efc5aff05ef4 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 29 Aug 2019 19:44:13 +0530 Subject: [PATCH 23/32] add basic tests --- test/context.jl | 46 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/test/context.jl b/test/context.jl index 24e9cccb..c0191802 100644 --- a/test/context.jl +++ b/test/context.jl @@ -1,6 +1,46 @@ using CuArrays, Test +using CuArrays.NNlib -W = rand(5, 5) -b = rand(5) +# Check simple ops work and broadcast +@testset "simple ops" begin + W = rand(5, 5) + b = rand(5) + @test cuda(() -> W*b) ≈ W*b -@test cuda(() -> W*b) ≈ W*b + a = rand(10) + b = rand(10) + + r = cuda() do + a + b + end + @test r isa Array + + r = cuda() do + a .+ b + end + @test r isa Array +end + +# Check that functions happen +@testset "linear" begin + linear(x, W, b) = (x * W) .+ b + w = rand(10, 10) + b = zeros(10) + x = rand(10,10) + r = cuda() do + linear(x, w, b) + end + @test r isa Array{Float32} +end + +# check that NNlib is wrapped correctly +@testset "conv Context" begin + w = rand(Float32, 3, 3, 3, 16) + r = rand(Float32, 32, 32, 3, 1) + g = cuda() do + conv(r, w) + end + g = conv(r, w) + @test c ≈ g + @test g isa Array +end From 8515b7f8e0863de3c3fd622686f6c1fc34bd2d67 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 29 Aug 2019 19:46:42 +0530 Subject: [PATCH 24/32] use to return cpu arrays --- src/context.jl | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/context.jl b/src/context.jl index 4a7f0f85..81b32f59 100644 --- a/src/context.jl +++ b/src/context.jl @@ -48,13 +48,13 @@ function cache(cx, x::CuArray{T,N})::Array{T,N} where {T,N} cx[cpu] = x return cpu end +cache(cx, f) = f for f in (:+, :-, :*, :/) @eval function (c::CUDACtx)(::typeof($f), a::AbstractArray, b::AbstractArray) ga = get_cached(array_bank, a) gb = get_cached(array_bank, b) - # cache(array_bank, $f(ga, gb)) - $f(ga, gb) + cache(array_bank, $f(ga, gb)) end end @@ -66,7 +66,7 @@ end function (c::CUDACtx)(::typeof(broadcasted), f, args...) gargs = map(x -> get_cached(array_bank, x), args) - broadcasted(f, gargs...) + broadcasted(f, gargs...) |> x -> cache(array_bank, x) end function (c::CUDACtx)(::typeof(getproperty), o, s::Symbol) @@ -75,7 +75,7 @@ end function (c::CUDACtx)(::typeof(broadcast), f, args...) gargs = map(x -> get_cached(array_bank, x), args) - broadcast(f, gargs...) + broadcast(f, gargs...) |> x -> cache(array_bank, x) end function (c::CUDACtx)(::typeof(getfield), o, s::Symbol) @@ -85,14 +85,15 @@ end function wrap_cuize(f) @eval function (c::CUDACtx)(::typeof($f), args...) gargs = map(get_cached, args) - $f(gargs...) # use `cache` + cache(array_bank, $f(gargs...)) end end -wrap_cuize.((sum, similar, )) +wrap_cuize.((sum, similar, materialize)) function (c::CUDACtx)(::typeof(reshape), arr, args...) - reshape(get_cached(arr), args...) + r = reshape(get_cached(arr), args...) + cache(array_bank, r) end @dynamo function (c::CUDACtx)(meta...) @@ -111,6 +112,7 @@ end return finish(pr) end +# TODO: remove arbitrary things in one place get_cached(array_bank, t::Union{Type, UnitRange, Function, Broadcasted, Symbol, Module, Nothing, Missing, Ptr, CuPtr, T}) where {T <: Real} = t get_cached(array_bank, t::Union{Tuple, NamedTuple}) = map(get_cached, t) @@ -118,7 +120,7 @@ get_cached(array_bank, x::CuArray) = x function get_cached(x::T) where T T <: Array && return get_cached(array_bank, x) - isstructtype(T) && return x # get_cached(obs, x) + isstructtype(T) && return x get_cached(array_bank, x) end @@ -129,23 +131,19 @@ function noop_pass(f) @eval (c::CUDACtx)(::typeof($f), args...) = $f(args...) end -noop_pass.((materialize, get_cached, NNlib.check_spdf, +noop_pass.((get_cached, NNlib.check_spdf, )) for f in names(NNlib) getfield(NNlib, f) isa Function || continue @eval function (c::CUDACtx)(::typeof($f), args...) gargs = map(get_cached, args) - # cache(array_bank, $f(gargs...)) - $f(gargs...) + cache(array_bank, $f(gargs...)) end end # Hold all the arrays related to the op -# TODO: this should be a context # BitArray and friends would like an AbstractArray construct -# const array_bank = IdDict{Array,CuArray}() - const array_bank = CUDACtx() """ @@ -166,6 +164,6 @@ function cuda(f) length(x) == length(cx) && continue refill!(x, cx) end - # empty!(array_bank) + empty!(array_bank) return out end From be857099c63666623c9e2f8d91f37864405620b1 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 30 Aug 2019 01:23:08 +0530 Subject: [PATCH 25/32] cleanups --- src/context.jl | 87 ++++++++++++++++++++------------------------------ 1 file changed, 34 insertions(+), 53 deletions(-) diff --git a/src/context.jl b/src/context.jl index 81b32f59..2cceef2d 100644 --- a/src/context.jl +++ b/src/context.jl @@ -14,19 +14,19 @@ end CUDACtx() = CUDACtx(IdDict{Array,CuArray}()) # Display fns for debugging, remove before committing -function Base.summary(io::IO, c::CUDACtx) +function Base.summary(io::IO, cx::CUDACtx) print(io, "IR Context for CUDA ") - summary(io, c.array_bank) + summary(io, cx.array_bank) end -function Base.show(io::IO, c::CUDACtx) +function Base.show(io::IO, cx::CUDACtx) print(io, "IR Context for CUDA ") - display(c.array_bank) + display(cx.array_bank) end @forward CUDACtx.array_bank Base.getindex, Base.iterate, Base.setindex!, Base.empty!, - Base.length, + Base.length, Base.get! Base.first, Base.last, Base.haskey function _resize!(a::Array, sz::NTuple{<:Any,Integer}) @@ -51,53 +51,47 @@ end cache(cx, f) = f for f in (:+, :-, :*, :/) - @eval function (c::CUDACtx)(::typeof($f), a::AbstractArray, b::AbstractArray) - ga = get_cached(array_bank, a) - gb = get_cached(array_bank, b) - cache(array_bank, $f(ga, gb)) + @eval function (cx::CUDACtx)(::typeof($f), a::AbstractArray, b::AbstractArray) + ga = get_cached(cx, a) + gb = get_cached(cx, b) + cache(cx, $f(ga, gb)) end end -function get_cached(array_bank, arr::Array{T,N})::CuArray{T,N} where {T,N} - haskey(array_bank, arr) ? - array_bank[arr] : - (array_bank[arr] = CuArray(arr)) +function get_cached(cx::CUDACtx, arr::Array{T,N})::CuArray{T,N} where {T,N} + get!(cx, arr, CuArray(arr)) end +get_cached(cx::CUDACtx, x) = x -function (c::CUDACtx)(::typeof(broadcasted), f, args...) - gargs = map(x -> get_cached(array_bank, x), args) - broadcasted(f, gargs...) |> x -> cache(array_bank, x) +function (cx::CUDACtx)(::typeof(broadcasted), f, args...) + gargs = map(x -> get_cached(cx, x), args) + broadcasted(f, gargs...) |> x -> cache(cx, x) end -function (c::CUDACtx)(::typeof(getproperty), o, s::Symbol) - getproperty(o, s) |> get_cached +function (cx::CUDACtx)(::typeof(getproperty), o, s::Symbol) + op = getproperty(o, s) |> x -> get_cached(cx, x) end -function (c::CUDACtx)(::typeof(broadcast), f, args...) - gargs = map(x -> get_cached(array_bank, x), args) - broadcast(f, gargs...) |> x -> cache(array_bank, x) -end - -function (c::CUDACtx)(::typeof(getfield), o, s::Symbol) - getfield(o, s) |> get_cached +function (cx::CUDACtx)(::typeof(broadcast), f, args...) + gargs = map(x -> get_cached(cx, x), args) + broadcast(f, gargs...) |> x -> cache(cx, x) end function wrap_cuize(f) - @eval function (c::CUDACtx)(::typeof($f), args...) - gargs = map(get_cached, args) - cache(array_bank, $f(gargs...)) + @eval function (cx::CUDACtx)(::typeof($f), args...) + gargs = map(x -> get_cached(cx, x), args) + cache(cx, $f(gargs...)) end end wrap_cuize.((sum, similar, materialize)) -function (c::CUDACtx)(::typeof(reshape), arr, args...) - r = reshape(get_cached(arr), args...) - cache(array_bank, r) +function (cx::CUDACtx)(::typeof(reshape), arr, args...) + r = reshape(get_cached(cx, arr), args...) + cache(cx, r) end -@dynamo function (c::CUDACtx)(meta...) - meta == nothing && return +@dynamo function (cx::CUDACtx)(meta...) ir = IR(meta...) ir == nothing && return @@ -112,18 +106,6 @@ end return finish(pr) end -# TODO: remove arbitrary things in one place -get_cached(array_bank, t::Union{Type, UnitRange, Function, Broadcasted, Symbol, Module, Nothing, Missing, Ptr, CuPtr, T}) where {T <: Real} = t -get_cached(array_bank, t::Union{Tuple, NamedTuple}) = map(get_cached, t) - -get_cached(array_bank, x::CuArray) = x - -function get_cached(x::T) where T - T <: Array && return get_cached(array_bank, x) - isstructtype(T) && return x - get_cached(array_bank, x) -end - """ Disable `CUDACtx` for a function """ @@ -136,15 +118,14 @@ noop_pass.((get_cached, NNlib.check_spdf, for f in names(NNlib) getfield(NNlib, f) isa Function || continue - @eval function (c::CUDACtx)(::typeof($f), args...) - gargs = map(get_cached, args) - cache(array_bank, $f(gargs...)) + @eval function (cx::CUDACtx)(::typeof($f), args...) + gargs = map(x -> get_cached(cx, x), args) + cache(cx, $f(gargs...)) end end # Hold all the arrays related to the op # BitArray and friends would like an AbstractArray construct -const array_bank = CUDACtx() """ Creates a `cuda` context within which we travel @@ -158,12 +139,12 @@ const array_bank = CUDACtx() end ``` """ -function cuda(f) - out = array_bank(f) - for (x, cx) in array_bank +function cuda(f, ctx = CUDACtx()) + out = ctx(f) + for (x, cx) in ctx length(x) == length(cx) && continue refill!(x, cx) end - empty!(array_bank) + empty!(ctx) return out end From 289fb46d930562fa2c07cb5284b6a7fb1bc0712f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 30 Aug 2019 01:31:33 +0530 Subject: [PATCH 26/32] remove getproperty fn --- src/context.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/context.jl b/src/context.jl index 2cceef2d..66a09653 100644 --- a/src/context.jl +++ b/src/context.jl @@ -68,10 +68,6 @@ function (cx::CUDACtx)(::typeof(broadcasted), f, args...) broadcasted(f, gargs...) |> x -> cache(cx, x) end -function (cx::CUDACtx)(::typeof(getproperty), o, s::Symbol) - op = getproperty(o, s) |> x -> get_cached(cx, x) -end - function (cx::CUDACtx)(::typeof(broadcast), f, args...) gargs = map(x -> get_cached(cx, x), args) broadcast(f, gargs...) |> x -> cache(cx, x) From 1d10f6269d40d1aacc3e26a006c3dde20df052b6 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 30 Aug 2019 01:47:17 +0530 Subject: [PATCH 27/32] context -> contextual --- Project.toml | 15 +++++++-------- src/CuArrays.jl | 2 +- src/{context.jl => contextual.jl} | 0 test/{context.jl => contextual.jl} | 0 test/runtests.jl | 2 +- 5 files changed, 9 insertions(+), 10 deletions(-) rename src/{context.jl => contextual.jl} (100%) rename test/{context.jl => contextual.jl} (100%) diff --git a/Project.toml b/Project.toml index 89ced327..933620cd 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,13 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test", "FFTW", "ForwardDiff"] + [compat] Adapt = "1.0" CUDAapi = "0.5.3, 0.6, 1.0" @@ -27,11 +34,3 @@ CUDAnative = "2.0" GPUArrays = "0.7.1, 1.0" NNlib = "0.6" julia = "1.0" - -[extras] -FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test", "FFTW", "ForwardDiff"] diff --git a/src/CuArrays.jl b/src/CuArrays.jl index 0ab24915..9cc939f4 100644 --- a/src/CuArrays.jl +++ b/src/CuArrays.jl @@ -81,7 +81,7 @@ include("dnn/CUDNN.jl") include("nnlib.jl") -include("context.jl") +include("contextual.jl") include("deprecated.jl") diff --git a/src/context.jl b/src/contextual.jl similarity index 100% rename from src/context.jl rename to src/contextual.jl diff --git a/test/context.jl b/test/contextual.jl similarity index 100% rename from test/context.jl rename to test/contextual.jl diff --git a/test/runtests.jl b/test/runtests.jl index dd2edfa4..cb7e593a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,7 +26,7 @@ include("solver.jl") include("sparse_solver.jl") include("dnn.jl") include("forwarddiff.jl") -include("context.jl") +include("contextual.jl") CuArrays.pool_status() CuArrays.pool_timings() From 2a342cad320693215dc2921ae511a3266431d099 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 30 Aug 2019 01:49:48 +0530 Subject: [PATCH 28/32] fixes --- Project.toml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 933620cd..1613a6b4 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +[extras] FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -27,10 +28,10 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" test = ["Test", "FFTW", "ForwardDiff"] [compat] -Adapt = "1.0" -CUDAapi = "0.5.3, 0.6, 1.0" -CUDAdrv = "3.0" +julia = "1.0" CUDAnative = "2.0" -GPUArrays = "0.7.1, 1.0" +CUDAdrv = "3.0" +CUDAapi = "0.5.3, 0.6, 1.0" NNlib = "0.6" -julia = "1.0" +GPUArrays = "0.7.1, 1.0" +Adapt = "1.0" From c07fa95da0d9c626cfcd6850f1fd73cf83d539d2 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 19 Sep 2019 22:21:38 +0530 Subject: [PATCH 29/32] move to macros --- src/contextual.jl | 46 ++++++++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/src/contextual.jl b/src/contextual.jl index 66a09653..b214fc29 100644 --- a/src/contextual.jl +++ b/src/contextual.jl @@ -50,14 +50,7 @@ function cache(cx, x::CuArray{T,N})::Array{T,N} where {T,N} end cache(cx, f) = f -for f in (:+, :-, :*, :/) - @eval function (cx::CUDACtx)(::typeof($f), a::AbstractArray, b::AbstractArray) - ga = get_cached(cx, a) - gb = get_cached(cx, b) - cache(cx, $f(ga, gb)) - end -end - +# TODO: BitArray and friends would like an AbstractArray construct function get_cached(cx::CUDACtx, arr::Array{T,N})::CuArray{T,N} where {T,N} get!(cx, arr, CuArray(arr)) end @@ -68,19 +61,24 @@ function (cx::CUDACtx)(::typeof(broadcasted), f, args...) broadcasted(f, gargs...) |> x -> cache(cx, x) end -function (cx::CUDACtx)(::typeof(broadcast), f, args...) - gargs = map(x -> get_cached(cx, x), args) - broadcast(f, gargs...) |> x -> cache(cx, x) -end +macro wrap_cuize(fs...) + ex = Expr[] + for f in fs + q = quote + function (cx::CUDACtx)(::typeof($f), args...) + gargs = map(x -> get_cached(cx, x), args) + cache(cx, $f(gargs...)) + end + end + push!(ex, q) + end -function wrap_cuize(f) - @eval function (cx::CUDACtx)(::typeof($f), args...) - gargs = map(x -> get_cached(cx, x), args) - cache(cx, $f(gargs...)) + quote + $(ex...) end end -wrap_cuize.((sum, similar, materialize)) +@wrap_cuize :+ :- :* :/ sum similar materialize function (cx::CUDACtx)(::typeof(reshape), arr, args...) r = reshape(get_cached(cx, arr), args...) @@ -105,12 +103,15 @@ end """ Disable `CUDACtx` for a function """ -function noop_pass(f) - @eval (c::CUDACtx)(::typeof($f), args...) = $f(args...) +macro noop_pass(fs...) + ex = [:( (cx::CUDACtx)(::typeof($f), args...) = $f(args...) ) for f in fs] + + quote + $(ex...) + end end -noop_pass.((get_cached, NNlib.check_spdf, - )) +@noop_pass get_cached NNlib.check_spdf for f in names(NNlib) getfield(NNlib, f) isa Function || continue @@ -120,9 +121,6 @@ for f in names(NNlib) end end -# Hold all the arrays related to the op -# BitArray and friends would like an AbstractArray construct - """ Creates a `cuda` context within which we travel through the entire callstack to find matrix/vector From e67dd220d96fac7f6dfe8dd955f39825bb72db56 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 20 Sep 2019 02:14:14 +0530 Subject: [PATCH 30/32] get rid of broadcasted def --- src/contextual.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/contextual.jl b/src/contextual.jl index b214fc29..4b553366 100644 --- a/src/contextual.jl +++ b/src/contextual.jl @@ -56,9 +56,9 @@ function get_cached(cx::CUDACtx, arr::Array{T,N})::CuArray{T,N} where {T,N} end get_cached(cx::CUDACtx, x) = x -function (cx::CUDACtx)(::typeof(broadcasted), f, args...) +function (cx::CUDACtx)(::typeof(Base._mapreducedim!), f, op, args...) gargs = map(x -> get_cached(cx, x), args) - broadcasted(f, gargs...) |> x -> cache(cx, x) + Base._mapreducedim!(f, op, gargs...) |> x-> cache(cx, x) end macro wrap_cuize(fs...) @@ -121,6 +121,14 @@ for f in names(NNlib) end end +for f in names(LinearAlgebra) + getfield(LinearAlgebra, f) isa Function || continue + @eval function (cx::CUDACtx)(::typeof($f), args...) + gargs = map(x -> get_cached(cx, x), args) + cache(cx, $f(gargs...)) + end +end + """ Creates a `cuda` context within which we travel through the entire callstack to find matrix/vector From 0c681fb7293c9e76b004e022a885444a4f940738 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 3 Oct 2019 11:41:12 +0530 Subject: [PATCH 31/32] check output type in test --- test/contextual.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/contextual.jl b/test/contextual.jl index c0191802..7c9c2620 100644 --- a/test/contextual.jl +++ b/test/contextual.jl @@ -5,7 +5,9 @@ using CuArrays.NNlib @testset "simple ops" begin W = rand(5, 5) b = rand(5) - @test cuda(() -> W*b) ≈ W*b + op = cuda(() -> W*b) + @test op ≈ W*b + @test op isa Array a = rand(10) b = rand(10) From 173dea946c63d4e230a3b9d0d13380ec32951ec6 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 4 Oct 2019 12:35:26 +0530 Subject: [PATCH 32/32] better macro name --- src/contextual.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contextual.jl b/src/contextual.jl index 4b553366..f2bc5f62 100644 --- a/src/contextual.jl +++ b/src/contextual.jl @@ -61,7 +61,7 @@ function (cx::CUDACtx)(::typeof(Base._mapreducedim!), f, op, args...) Base._mapreducedim!(f, op, gargs...) |> x-> cache(cx, x) end -macro wrap_cuize(fs...) +macro contextual(fs...) ex = Expr[] for f in fs q = quote @@ -78,7 +78,7 @@ macro wrap_cuize(fs...) end end -@wrap_cuize :+ :- :* :/ sum similar materialize +@contextual :+ :- :* :/ sum similar materialize function (cx::CUDACtx)(::typeof(reshape), arr, args...) r = reshape(get_cached(cx, arr), args...)