From 745e4ee22bcbcb77283f4f664ad202ebeafd7863 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 15 Sep 2021 10:16:58 -0400 Subject: [PATCH 1/9] add split_bc_rule --- Project.toml | 2 +- src/stage1/broadcast.jl | 86 ++++++++++++++++++++++++++++++++++++++++- test/runtests.jl | 9 +++++ 3 files changed, 95 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 5af68933..6a0a56fa 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] ChainRules = "1.5" -ChainRulesCore = "1.2" +ChainRulesCore = "1.4" Combinatorics = "1" StaticArrays = "1" StatsBase = "0.33" diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index 88fc0b39..8ae75824 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -34,6 +34,90 @@ function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} ∂⃖ₙ(map, f, a) end +using ChainRulesCore: derivatives_given_output + +(::∂⃖{1})(::typeof(broadcasted), f, args...) = split_bc_rule(f, args...) +(::∂⃖{1})(::typeof(broadcasted), f, arg::Array) = split_bc_rule(f, arg) # ambiguity +function split_bc_rule(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if T == Bool && Base.issingletontype(F) + # Trivial case + back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2) + return f.(args...), back_1 + elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type( + derivatives_given_output, Tuple{T, F, map(eltype, args)...})) + # Fast path: just broadcast, and use x & y to find derivative. + ys = f.(args...) + # println("2") + function back_2(dys) + deltas = splitcast(unthunk(dys), ys, args...) do dy, y, as... + das = only(derivatives_given_output(y, f, as...)) + map(da -> dy * conj(da), das) + end + dargs = map(unbroadcast, args, deltas) + (NoTangent(), NoTangent(), dargs...) + end + return ys, back_2 + else + # Slow path: collect all the pullbacks & apply them later. + # println("3") + ys, backs = splitcast(rrule_via_ad, DiffractorRuleConfig(), f, args...) + function back_3(dys) + deltas = splitmap(backs, unthunk(dys)) do back, dy + map(unthunk, back(dy)) + end + dargs = map(unbroadcast, args, Base.tail(deltas)) # no real need to close over args here + (NoTangent(), sum(first(deltas)), dargs...) + end + return ys, back_3 + end +end + +using StructArrays +splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...))) # warning: splitmap(identity, [1,2,3,4]) === NamedTuple() +splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...)))) + +unbroadcast(f::Function, x̄) = accum_sum(x̄) +unbroadcast(::Val, _) = NoTangent() +accum_sum(xs::AbstractArray{<:NoTangent}; dims = :) = NoTangent() + +#= + +julia> xs = randn(10_000); +julia> @btime Zygote.gradient(x -> sum(abs2, x), $xs) + 4.744 μs (2 allocations: 78.17 KiB) +julia> @btime Diffractor.unthunk.(gradient(x -> sum(abs2, x), $xs)); + 3.307 μs (2 allocations: 78.17 KiB) + +# Simple function + +julia> @btime Zygote.gradient(x -> sum(abs2, exp.(x)), $xs); + 72.541 μs (29 allocations: 391.47 KiB) # with dual numbers -- like 4 copies + +julia> @btime gradient(x -> sum(abs2, exp.(x)), $xs); + 45.875 μs (36 allocations: 235.47 KiB) # fast path -- one copy forward, one back + 44.042 μs (32 allocations: 313.48 KiB) # slow path -- 3 copies, extra is closure? + 61.167 μs (12 allocations: 703.41 KiB) # with `map` rule as before -- worse + +# Composed function, Zygote struggles + +julia> @btime Zygote.gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs); + 97.167 μs (29 allocations: 391.61 KiB) # with dual numbers (Zygote master) + 93.238 ms (849567 allocations: 19.22 MiB) # without, thus Zygote.pullback + +julia> @btime gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs); + 55.290 ms (830060 allocations: 49.75 MiB) # slow path + 14.747 ms (240043 allocations: 7.25 MiB) # with `map` rule as before -- better! + +# Compare unfused + +julia> @btime gradient(x -> sum(abs2, identity.(cbrt.(x))), $xs); + 69.458 μs (50 allocations: 392.09 KiB) # fast path -- two copies forward, two back + 75.041 μs (46 allocations: 470.11 KiB) # slow path -- 5 copies + 135.541 μs (27 allocations: 1.30 MiB) # with `map` rule as before -- worse + +=# + # The below is from Zygote: TODO: DO we want to do something better here? accum_sum(xs::Nothing; dims = :) = NoTangent() @@ -70,4 +154,4 @@ ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y, - z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end \ No newline at end of file + z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end diff --git a/test/runtests.jl b/test/runtests.jl index 51335369..d70bec5b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -214,5 +214,14 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test z45 ≈ 2.0 @test delta45 ≈ 1.0 +# Broadcasting +@test gradient(x -> sum(x ./ x), [1,2,3]) == ([1,1,1],) +@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) # derivatives_given_output +@test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback +@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) +@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4) +@test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool shortcut +@test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent()) + # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) #include("pinn.jl") From b484a14b518ef4411d8dbb4d55a75912b0d01cec Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 16 Sep 2021 14:11:54 -0400 Subject: [PATCH 2/9] lazier +,-,* rules --- src/stage1/broadcast.jl | 74 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 8 deletions(-) diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index 8ae75824..f9ed110b 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -44,7 +44,8 @@ function split_bc_rule(f::F, args...) where {F} # Trivial case back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2) return f.(args...), back_1 - elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type( + # elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type( + elseif isconcretetype(Core.Compiler._return_type( derivatives_given_output, Tuple{T, F, map(eltype, args)...})) # Fast path: just broadcast, and use x & y to find derivative. ys = f.(args...) @@ -79,6 +80,7 @@ splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiat unbroadcast(f::Function, x̄) = accum_sum(x̄) unbroadcast(::Val, _) = NoTangent() +unbroadcast(x::AbstractArray, x̄::NoTangent) = NoTangent() accum_sum(xs::AbstractArray{<:NoTangent}; dims = :) = NoTangent() #= @@ -99,6 +101,7 @@ julia> @btime gradient(x -> sum(abs2, exp.(x)), $xs); 44.042 μs (32 allocations: 313.48 KiB) # slow path -- 3 copies, extra is closure? 61.167 μs (12 allocations: 703.41 KiB) # with `map` rule as before -- worse + # Composed function, Zygote struggles julia> @btime Zygote.gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs); @@ -116,6 +119,27 @@ julia> @btime gradient(x -> sum(abs2, identity.(cbrt.(x))), $xs); 75.041 μs (46 allocations: 470.11 KiB) # slow path -- 5 copies 135.541 μs (27 allocations: 1.30 MiB) # with `map` rule as before -- worse + +# Lazy +,-,* for partial fusing + +julia> @btime Zygote.gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs); + 81.250 μs (21 allocations: 625.47 KiB) # special rules + dual numbers, 4 more copies than best + +julia> @btime gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs); + 57.166 μs (49 allocations: 470.22 KiB) # broadcast in *, - + 54.583 μs (46 allocations: 314.06 KiB) # broadcasted -- two less copies + 72.958 μs (26 allocations: 1016.38 KiB) # with `map` rule as before + +julia> gradient((x,y) -> sum(abs2, exp.(2 .* x .+ y)), xs, (rand(10)')) +ERROR: MethodError: no method matching size(::Base.Broadcast.Broadcasted # hmm + +julia> @btime gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)')); + 7.127 ms (75 allocations: 22.97 MiB) # after + 12.956 ms (57 allocations: 76.37 MiB) # before + +ulia> @btime Zygote.gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)')); + 9.937 ms (48 allocations: 45.86 MiB) + =# # The below is from Zygote: TODO: DO we want to do something better here? @@ -133,7 +157,7 @@ end trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) -unbroadcast(x::AbstractArray, x̄) = +unbroadcast(x::Union{AbstractArray, Base.Broadcast.Broadcasted}, x̄) = size(x) == size(x̄) ? x̄ : length(x) == length(x̄) ? trim(x, x̄) : trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄))))) @@ -146,12 +170,46 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent() const Numeric = Union{Number, AbstractArray{<:Number, N} where N} -function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...) - broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...) +# function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...) +# broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...) +# end + +# Replace Zygote-like fully split broadcasting with one fused over easy operations +(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = split_bc_plus(args...) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = split_bc_plus(arg) # ambiguity +function split_bc_plus(xs...) where {F} + broadcasted(+, xs...), Δ -> let Δun = unthunk(Δ) + # println("+") + (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δun), xs)...) + end end +Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) = + mapreduce(eltype, promote_type, bc.args) # needed to hit fast path -ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y, - Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end +(::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ) -ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y, - z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end +# ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y, +# Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end + +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y) + broadcasted(-, x, y), Δ -> let Δun = unthunk(Δ) + # println("-") + (NoTangent(), NoTangent(), unbroadcast(x, Δun), -unbroadcast(y, Δun)) + # Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array + end +end + +# ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y, +# z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end + +using LinearAlgebra: dot + +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y) + broadcasted(*, x, y), Δ -> let Δun = unthunk(Δ) + # println("*") + dx = x isa Number ? dot(y, Δun) : unbroadcast(x, Δun .* conj.(y)) + dy = y isa Number ? dot(x, Δun) : unbroadcast(y, Δun .* conj.(x)) + # When x is an array but a smaller one, instead of dot you may be able to use mapreduce() + (NoTangent(), NoTangent(), dx, dy) + end +end From ebd17002179086ea6be6b6e58151b6572c5fadaf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 9 Jan 2022 15:13:33 -0500 Subject: [PATCH 3/9] update, rm comments --- Project.toml | 5 +- src/extra_rules.jl | 20 ++--- src/stage1/broadcast.jl | 164 +++++++++++++--------------------------- test/runtests.jl | 31 +++++++- 4 files changed, 94 insertions(+), 126 deletions(-) diff --git a/Project.toml b/Project.toml index 6a0a56fa..280e0b8d 100644 --- a/Project.toml +++ b/Project.toml @@ -8,13 +8,14 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] -ChainRules = "1.5" -ChainRulesCore = "1.4" +ChainRules = "1.17" +ChainRulesCore = "1.11" Combinatorics = "1" StaticArrays = "1" StatsBase = "0.33" diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 4d22c172..fcdb0a12 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -137,9 +137,10 @@ struct NonDiffOdd{N, O, P}; end # This should not happen (::NonDiffEven{N, O, O})(Δ...) where {N, O} = error() -@Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...) - Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}() -end +# WARNING: Method definition rrule(typeof(Core.apply_type), Any, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Core/core.jl:10 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:140. +# @Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...) +# Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}() +# end function ChainRulesCore.rrule(::typeof(Core.tuple), args...) Core.tuple(args...), Δ->Core.tuple(NoTangent(), Δ...) @@ -206,12 +207,13 @@ function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where { AT(x), Δ->(NoTangent(), Δ) end -function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...) - # We're leaving these in the eltype that the cotangent vector already has. - # There isn't really a good reason to believe we should convert to the - # original array type, so don't unless explicitly requested. - AT(undef, args...), Δ->(NoTangent(), NoTangent(), ntuple(_->NoTangent(), length(args))...) -end +# WARNING: Method definition rrule(Type{var"#s260"} where var"#s260"<:(Array{T, N} where N where T), UndefInitializer, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Base/array.jl:5 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:209. +# function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...) +# # We're leaving these in the eltype that the cotangent vector already has. +# # There isn't really a good reason to believe we should convert to the +# # original array type, so don't unless explicitly requested. +# AT(undef, args...), Δ->(NoTangent(), NoTangent(), ntuple(_->NoTangent(), length(args))...) +# end function unzip_tuple(t::Tuple) map(x->x[1], t), map(x->x[2], t) diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index f9ed110b..1704e111 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -29,8 +29,12 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)}, return r end +_print(s) = nothing +# _print(s) = printstyled(s, "\n"; color=:magenta) + # Broadcast over one element is just map function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} + _print("path 0") ∂⃖ₙ(map, f, a) end @@ -40,16 +44,16 @@ using ChainRulesCore: derivatives_given_output (::∂⃖{1})(::typeof(broadcasted), f, arg::Array) = split_bc_rule(f, arg) # ambiguity function split_bc_rule(f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) - if T == Bool && Base.issingletontype(F) + if T == Bool # Trivial case + _print("path 1") back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2) return f.(args...), back_1 - # elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type( elseif isconcretetype(Core.Compiler._return_type( derivatives_given_output, Tuple{T, F, map(eltype, args)...})) # Fast path: just broadcast, and use x & y to find derivative. ys = f.(args...) - # println("2") + _print("path 2") function back_2(dys) deltas = splitcast(unthunk(dys), ys, args...) do dy, y, as... das = only(derivatives_given_output(y, f, as...)) @@ -61,7 +65,7 @@ function split_bc_rule(f::F, args...) where {F} return ys, back_2 else # Slow path: collect all the pullbacks & apply them later. - # println("3") + _print("path 3") ys, backs = splitcast(rrule_via_ad, DiffractorRuleConfig(), f, args...) function back_3(dys) deltas = splitmap(backs, unthunk(dys)) do back, dy @@ -78,108 +82,13 @@ using StructArrays splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...))) # warning: splitmap(identity, [1,2,3,4]) === NamedTuple() splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...)))) -unbroadcast(f::Function, x̄) = accum_sum(x̄) -unbroadcast(::Val, _) = NoTangent() -unbroadcast(x::AbstractArray, x̄::NoTangent) = NoTangent() -accum_sum(xs::AbstractArray{<:NoTangent}; dims = :) = NoTangent() - -#= - -julia> xs = randn(10_000); -julia> @btime Zygote.gradient(x -> sum(abs2, x), $xs) - 4.744 μs (2 allocations: 78.17 KiB) -julia> @btime Diffractor.unthunk.(gradient(x -> sum(abs2, x), $xs)); - 3.307 μs (2 allocations: 78.17 KiB) - -# Simple function - -julia> @btime Zygote.gradient(x -> sum(abs2, exp.(x)), $xs); - 72.541 μs (29 allocations: 391.47 KiB) # with dual numbers -- like 4 copies - -julia> @btime gradient(x -> sum(abs2, exp.(x)), $xs); - 45.875 μs (36 allocations: 235.47 KiB) # fast path -- one copy forward, one back - 44.042 μs (32 allocations: 313.48 KiB) # slow path -- 3 copies, extra is closure? - 61.167 μs (12 allocations: 703.41 KiB) # with `map` rule as before -- worse - - -# Composed function, Zygote struggles - -julia> @btime Zygote.gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs); - 97.167 μs (29 allocations: 391.61 KiB) # with dual numbers (Zygote master) - 93.238 ms (849567 allocations: 19.22 MiB) # without, thus Zygote.pullback - -julia> @btime gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs); - 55.290 ms (830060 allocations: 49.75 MiB) # slow path - 14.747 ms (240043 allocations: 7.25 MiB) # with `map` rule as before -- better! - -# Compare unfused - -julia> @btime gradient(x -> sum(abs2, identity.(cbrt.(x))), $xs); - 69.458 μs (50 allocations: 392.09 KiB) # fast path -- two copies forward, two back - 75.041 μs (46 allocations: 470.11 KiB) # slow path -- 5 copies - 135.541 μs (27 allocations: 1.30 MiB) # with `map` rule as before -- worse - - -# Lazy +,-,* for partial fusing - -julia> @btime Zygote.gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs); - 81.250 μs (21 allocations: 625.47 KiB) # special rules + dual numbers, 4 more copies than best - -julia> @btime gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs); - 57.166 μs (49 allocations: 470.22 KiB) # broadcast in *, - - 54.583 μs (46 allocations: 314.06 KiB) # broadcasted -- two less copies - 72.958 μs (26 allocations: 1016.38 KiB) # with `map` rule as before - -julia> gradient((x,y) -> sum(abs2, exp.(2 .* x .+ y)), xs, (rand(10)')) -ERROR: MethodError: no method matching size(::Base.Broadcast.Broadcasted # hmm - -julia> @btime gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)')); - 7.127 ms (75 allocations: 22.97 MiB) # after - 12.956 ms (57 allocations: 76.37 MiB) # before - -ulia> @btime Zygote.gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)')); - 9.937 ms (48 allocations: 45.86 MiB) - -=# +# For certain cheap operations we can easily allow fused broadcast: -# The below is from Zygote: TODO: DO we want to do something better here? - -accum_sum(xs::Nothing; dims = :) = NoTangent() -accum_sum(xs::AbstractArray{Nothing}; dims = :) = NoTangent() -accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims) -accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims) -accum_sum(xs::Number; dims = :) = xs - -# https://github.com/FluxML/Zygote.jl/issues/594 -function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray, region) - Base.reducedim_initarray(A, region, NoTangent(), Union{Nothing,eltype(A)}) -end - -trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) - -unbroadcast(x::Union{AbstractArray, Base.Broadcast.Broadcasted}, x̄) = - size(x) == size(x̄) ? x̄ : - length(x) == length(x̄) ? trim(x, x̄) : - trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄))))) - -unbroadcast(x::Number, x̄) = accum_sum(x̄) -unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),) -unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),) - -unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent() - -const Numeric = Union{Number, AbstractArray{<:Number, N} where N} - -# function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...) -# broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...) -# end - -# Replace Zygote-like fully split broadcasting with one fused over easy operations (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = split_bc_plus(args...) (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = split_bc_plus(arg) # ambiguity function split_bc_plus(xs...) where {F} broadcasted(+, xs...), Δ -> let Δun = unthunk(Δ) - # println("+") + _print("broadcast +") (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δun), xs)...) end end @@ -188,28 +97,61 @@ Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) = (::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ) -# ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y, -# Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end - function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y) broadcasted(-, x, y), Δ -> let Δun = unthunk(Δ) - # println("-") + _print("broadcast -") (NoTangent(), NoTangent(), unbroadcast(x, Δun), -unbroadcast(y, Δun)) # Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array end end -# ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y, -# z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end - using LinearAlgebra: dot function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y) broadcasted(*, x, y), Δ -> let Δun = unthunk(Δ) - # println("*") - dx = x isa Number ? dot(y, Δun) : unbroadcast(x, Δun .* conj.(y)) - dy = y isa Number ? dot(x, Δun) : unbroadcast(y, Δun .* conj.(x)) + _print("broadcast *") + dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δun) : unbroadcast(x, Δun .* conj.(y)) + dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δun) : unbroadcast(y, Δun .* conj.(x)) # When x is an array but a smaller one, instead of dot you may be able to use mapreduce() + # Will things like this work? Ref([1,2]) .* [1,2,3] (NoTangent(), NoTangent(), dx, dy) end end + +(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) = + broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ))) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) = + x, Δ -> (NoTangent(), Δ) + +function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx) + N = ndims(dx) + if length(x) == length(dx) + ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors + else + # This is an awful hack to get type-stable `dims` + dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) + ProjectTo(x)(sum(dx; dims)) + end +end +unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::NoTangent) = NoTangent() + +unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx)) +unbroadcast(f::Function, df) = ProjectTo(x)(sum(df)) +unbroadcast(x::Base.RefValue, dx) = ProjectTo(x)(Ref(sum(dx))) + +unbroadcast(::Bool, dx) = NoTangent() +unbroadcast(::AbstractArray{Bool}, dx) = NoTangent() +unbroadcast(::AbstractArray{Bool}, ::NoTangent) = NoTangent() # ambiguity +unbroadcast(::Val, dx) = NoTangent() +# Maybe more non-diff types? Some fallback? + +unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx))) +function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} + _print("unbroadcast tuple") + val = if length(x) == length(dx) + dx + else + sum(dx; dims=2:ndims(dx)) + end + ProjectTo(x)(NTuple{length(x)}(val)) # Tangent +end diff --git a/test/runtests.jl b/test/runtests.jl index d70bec5b..fa31ca63 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -215,13 +215,36 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test delta45 ≈ 1.0 # Broadcasting -@test gradient(x -> sum(x ./ x), [1,2,3]) == ([1,1,1],) -@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) # derivatives_given_output -@test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback +@test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output +@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) + +@test_broken gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback +exp_log(x) = exp(log(x)) +@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4) -@test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool shortcut +@test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure + +@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays +@test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 +@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 + +@test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output +@test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),) @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent()) +@test gradient(x -> sum(x .+ [1,2,3]), true) == (NoTangent(),) # Bool input +@test gradient(x -> sum(x ./ [1,2,3]), [true false]) == (NoTangent(),) +@test gradient(x -> sum(x .* [1,2,3]'), (true, false)) == (NoTangent(),) + +tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), [3,4,5]') +@test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0) +@test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] +@test tup_adj[2] isa Adjoint +@test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal + +@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3]) # path 0, MethodError: no method matching Diffractor.Jet(::Int64, ::Float64, ::Tuple{Float64, Float64}) +@test_broken gradient(x -> sum(gradient(x -> sum(x' .* x), x)[1]), [1,2,3]) == ([6,6,6],) # Control flow support not fully implemented yet for higher-order reverse mode +@test_broken gradient(x -> sum(gradient(x -> sum(x' ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) #include("pinn.jl") From 66f2b6a0f850b09f982c92d7a39e7e371c8b96a6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 10 Jan 2022 01:44:57 -0500 Subject: [PATCH 4/9] tidy, add more lazy cases --- src/stage1/broadcast.jl | 97 +++++++++++++++++++++++++++++++---------- test/runtests.jl | 4 ++ 2 files changed, 79 insertions(+), 22 deletions(-) diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index 1704e111..607f4ef1 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -29,8 +29,12 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)}, return r end -_print(s) = nothing -# _print(s) = printstyled(s, "\n"; color=:magenta) +# Reverse mode broadcast rules + +using ChainRulesCore: derivatives_given_output + +# _print(s) = nothing +_print(s) = printstyled(s, "\n"; color=:magenta) # Broadcast over one element is just map function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} @@ -38,19 +42,17 @@ function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} ∂⃖ₙ(map, f, a) end -using ChainRulesCore: derivatives_given_output - (::∂⃖{1})(::typeof(broadcasted), f, args...) = split_bc_rule(f, args...) (::∂⃖{1})(::typeof(broadcasted), f, arg::Array) = split_bc_rule(f, arg) # ambiguity function split_bc_rule(f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) - if T == Bool - # Trivial case + TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...}) + if eltype(T) == Bool + # Trivial case: non-differentiable output _print("path 1") back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2) return f.(args...), back_1 - elseif isconcretetype(Core.Compiler._return_type( - derivatives_given_output, Tuple{T, F, map(eltype, args)...})) + elseif T <: Number && isconcretetype(TΔ) # Fast path: just broadcast, and use x & y to find derivative. ys = f.(args...) _print("path 2") @@ -65,8 +67,9 @@ function split_bc_rule(f::F, args...) where {F} return ys, back_2 else # Slow path: collect all the pullbacks & apply them later. + # Since broadcast makes no guarantee about order, this does not bother to try to reverse it. _print("path 3") - ys, backs = splitcast(rrule_via_ad, DiffractorRuleConfig(), f, args...) + ys, backs = splitcast(∂⃖{1}(), f, args...) function back_3(dys) deltas = splitmap(backs, unthunk(dys)) do back, dy map(unthunk, back(dy)) @@ -78,8 +81,11 @@ function split_bc_rule(f::F, args...) where {F} end end +# This uses "mulltimap"-like constructs: + using StructArrays -splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...))) # warning: splitmap(identity, [1,2,3,4]) === NamedTuple() +splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...))) +# warning: splitmap(identity, [1,2,3,4]) === NamedTuple() splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...)))) # For certain cheap operations we can easily allow fused broadcast: @@ -107,7 +113,7 @@ end using LinearAlgebra: dot -function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y) +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y) # should this be vararg, or will laziness handle it? broadcasted(*, x, y), Δ -> let Δun = unthunk(Δ) _print("broadcast *") dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δun) : unbroadcast(x, Δun .* conj.(y)) @@ -117,41 +123,88 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y) (NoTangent(), NoTangent(), dx, dy) end end +# Alternative to `x isa Number` etc above... but not quite right! +# (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y::Number) = rrule_via_ad(DiffractorRuleConfig(), *, x, y) + +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x, ::Val{2}) + _print("broadcast ^2") + broadcasted(*, x, x), Δ -> begin + dx = unbroadcast(x, 2 .* Δ .* conj.(x)) + (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent()) + end +end +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) + _print("simple ^2") + x^2, Δ -> (NoTangent(), NoTangent(), NoTangent(), 2 * Δ * conj(x), NoTangent()) +end + +# function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x, y) # not obvious whether this is better than automatic +# broadcasted(/, x, y), Δ -> let Δun = unthunk(Δ) +# _print("broadcast /") +# dx = unbroadcast(x, Δ ./ conj.(y)) +# dy = unbroadcast(y, .-Δ .* conj.(res ./ y)) +# (NoTangent(), NoTangent(), dx, dy) +# end +# end +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x, y::Number) + _print("simple /") + z, back = ∂⃖{1}()(/, x, y) + z, Δ -> begin + _, dx, dy = back(Δ) + (NoTangent(), NoTangent(), dx, dy) # maybe there should be a funciton for this? Use for conj, identity too + end +end (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) = broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ))) (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) = x, Δ -> (NoTangent(), Δ) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) = + x, Δ -> (NoTangent(), Δ) + +# All broadcasts use `unbroadcast` to reduce to correct shape: + function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx) N = ndims(dx) if length(x) == length(dx) ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors else - # This is an awful hack to get type-stable `dims` - dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) + dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # awful hack to get type-stable `dims` ProjectTo(x)(sum(dx; dims)) end end unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::NoTangent) = NoTangent() +unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx))) +function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} + _print("unbroadcast tuple") + val = if length(x) == length(dx) + dx + else + sum(dx; dims=2:ndims(dx)) + end + ProjectTo(x)(NTuple{length(x)}(val)) # Tangent +end + +unbroadcast(f::Function, df) = sum(df) unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx)) -unbroadcast(f::Function, df) = ProjectTo(x)(sum(df)) unbroadcast(x::Base.RefValue, dx) = ProjectTo(x)(Ref(sum(dx))) unbroadcast(::Bool, dx) = NoTangent() unbroadcast(::AbstractArray{Bool}, dx) = NoTangent() unbroadcast(::AbstractArray{Bool}, ::NoTangent) = NoTangent() # ambiguity unbroadcast(::Val, dx) = NoTangent() -# Maybe more non-diff types? Some fallback? -unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx))) -function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} - _print("unbroadcast tuple") - val = if length(x) == length(dx) - dx +function unbroadcast(x, dx) + p = ProjectTo(x) + if dx isa AbstractZero || p isa ProjectTo{<:AbstractZero} + return NoTangent() + end + b = Broadcast.broadcastable(x) + if b isa Ref # then x is scalar under broadcast + return p(sum(dx)) else - sum(dx; dims=2:ndims(dx)) + error("don't know how to handle broadcast gradient for x::$(typeof(x))") end - ProjectTo(x)(NTuple{length(x)}(val)) # Tangent end diff --git a/test/runtests.jl b/test/runtests.jl index fa31ca63..76a54dbc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -216,6 +216,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) # Broadcasting @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output +@test gradient(x -> sum(sqrt.(atan.(x, x'))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) @test_broken gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback @@ -229,6 +230,9 @@ exp_log(x) = exp(log(x)) @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 +@test unthunk.(gradient(x -> sum(x ./ 4), [1,2,3])) == ([0.25, 0.25, 0.25],) +@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) + @test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),) @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent()) From 63797fd550e3ddff2ab54ac9f2d283d93d8514c2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 10 Jan 2022 01:46:47 -0500 Subject: [PATCH 5/9] avoid StructArrays sometimes, add unzip --- src/stage1/broadcast.jl | 69 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 3 deletions(-) diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index 607f4ef1..0b49da12 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -56,15 +56,22 @@ function split_bc_rule(f::F, args...) where {F} # Fast path: just broadcast, and use x & y to find derivative. ys = f.(args...) _print("path 2") - function back_2(dys) + function back_2_one(dys) # For f.(x) we do not need StructArrays / unzip at all + delta = broadcast(unthunk(dys), ys, args...) do dy, y, a + das = only(derivatives_given_output(y, f, a)) + dy * conj(only(das)) + end + (NoTangent(), NoTangent(), unbroadcast(only(args), delta)) + end + function back_2_many(dys) deltas = splitcast(unthunk(dys), ys, args...) do dy, y, as... das = only(derivatives_given_output(y, f, as...)) map(da -> dy * conj(da), das) end - dargs = map(unbroadcast, args, deltas) + dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of splitcast? (NoTangent(), NoTangent(), dargs...) end - return ys, back_2 + return ys, length(args)==1 ? back_2_one : back_2_many else # Slow path: collect all the pullbacks & apply them later. # Since broadcast makes no guarantee about order, this does not bother to try to reverse it. @@ -88,6 +95,62 @@ splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args # warning: splitmap(identity, [1,2,3,4]) === NamedTuple() splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...)))) +#= +# This is how you could handle CuArrays, route them to unzip(map(...)) fallback path. +# Maybe 2nd derivatives too, to avoid writing a gradient for splitcast, rule for unzip is easy. + +function Diffractor.splitmap(f, args...) + if any(a -> a isa CuArray, args) + Diffractor._print("unzip splitmap") + unzip(map(f, args...)) + else + StructArrays.components(StructArray(Iterators.map(f, args...))) + end +end +function Diffractor.splitcast(f, args...) + if any(a -> a isa CuArray, args) + Diffractor._print("unzip splitcast") + unzip(broadcast(f, args...)) + else + StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...)))) + end +end + +gradient(x -> sum(log.(x) .+ x'), cu([1,2,3]))[1] +gradient(x -> sum(sqrt.(atan.(x, x'))), cu([1,2,3]))[1] + +=# + +function unzip(xs::AbstractArray) + x1 = first(xs) + x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples")) + N = length(x1) + unzip(xs, Val(N)) # like Zygote's unzip +end +@generated function unzip(xs, ::Val{N}) where {N} + each = [:(map($(Get(i)), xs)) for i in 1:N] + Expr(:tuple, each...) +end +unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy +@generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple} + each = if count(!Base.issingletontype, Ts.parameters) < 2 + # good case, no copy of data, some trivial arrays + [Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters] + else + [:(map($(Get(i)), xs)) for i in 1:length(fieldnames(Ts))] + end + Expr(:tuple, each...) +end + +struct Get{i} end +Get(i) = Get{Int(i)}() +(::Get{i})(x) where {i} = x[i] + +function ChainRulesCore.rrule(::typeof(unzip), xs::AbstractArray) + rezip(dy) = (NoTangent(), tuple.(unthunk(dy)...)) + return unzip(xs), rezip +end + # For certain cheap operations we can easily allow fused broadcast: (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = split_bc_plus(args...) From 27328c2f9bbcc1282617469af99c60c1ee12a3a1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 13 Jan 2022 20:41:37 -0500 Subject: [PATCH 6/9] more on broadcasting --- src/extra_rules.jl | 18 +++++++++++------- src/stage1/broadcast.jl | 31 +++++++++++++++++++------------ src/stage1/generated.jl | 2 ++ test/runtests.jl | 4 +++- 4 files changed, 35 insertions(+), 20 deletions(-) diff --git a/src/extra_rules.jl b/src/extra_rules.jl index fcdb0a12..526fa629 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -147,16 +147,16 @@ function ChainRulesCore.rrule(::typeof(Core.tuple), args...) end # TODO: What to do about these integer rules -@ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type) +# @ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type) # now in CR 1.18 ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent() -# Skip AD'ing through the axis computation -function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted) - return Base.Broadcast.instantiate(bc), Δ->begin - Core.tuple(NoTangent(), Δ) - end -end +# # Skip AD'ing through the axis computation +# function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted) +# return Base.Broadcast.instantiate(bc), Δ->begin +# Core.tuple(NoTangent(), Δ) +# end +# end using StaticArrays @@ -268,3 +268,7 @@ end function ChainRulesCore.rrule(::Type{InplaceableThunk}, add!!, val) val, Δ->(NoTangent(), NoTangent(), Δ) end + +# ERROR: ArgumentError: Tangent for the primal Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}} should be backed by a AbstractDict type, not by NamedTuple{(:data,), Tuple{ChainRulesCore.ZeroTangent}}. +ChainRulesCore._backing_error(::Type{<:Base.Pairs{Symbol}}, ::Type{<:NamedTuple}, _) = nothing # solves that! + diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index 0b49da12..49333eac 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -44,7 +44,7 @@ end (::∂⃖{1})(::typeof(broadcasted), f, args...) = split_bc_rule(f, args...) (::∂⃖{1})(::typeof(broadcasted), f, arg::Array) = split_bc_rule(f, arg) # ambiguity -function split_bc_rule(f::F, args...) where {F} +function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N} T = Broadcast.combine_eltypes(f, args) TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...}) if eltype(T) == Bool @@ -71,10 +71,11 @@ function split_bc_rule(f::F, args...) where {F} dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of splitcast? (NoTangent(), NoTangent(), dargs...) end - return ys, length(args)==1 ? back_2_one : back_2_many + return ys, N==1 ? back_2_one : back_2_many else # Slow path: collect all the pullbacks & apply them later. - # Since broadcast makes no guarantee about order, this does not bother to try to reverse it. + # (Since broadcast makes no guarantee about order of calls, and un-fusing + # can change the number of calls, this does not bother to try to reverse.) _print("path 3") ys, backs = splitcast(∂⃖{1}(), f, args...) function back_3(dys) @@ -84,15 +85,21 @@ function split_bc_rule(f::F, args...) where {F} dargs = map(unbroadcast, args, Base.tail(deltas)) # no real need to close over args here (NoTangent(), sum(first(deltas)), dargs...) end + back_3(::AbstractZero) = (NoTangent(), map(Returns(ZeroTangent()), args)...) return ys, back_3 end end -# This uses "mulltimap"-like constructs: +# Skip AD'ing through the axis computation +function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted) + uninstantiate(Δ) = Core.tuple(NoTangent(), Δ) + return Base.Broadcast.instantiate(bc), uninstantiate +end + +# This uses "multimap"-like constructs: using StructArrays splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...))) -# warning: splitmap(identity, [1,2,3,4]) === NamedTuple() splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...)))) #= @@ -156,9 +163,9 @@ end (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = split_bc_plus(args...) (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = split_bc_plus(arg) # ambiguity function split_bc_plus(xs...) where {F} - broadcasted(+, xs...), Δ -> let Δun = unthunk(Δ) + broadcasted(+, xs...), Δraw -> let Δ = unthunk(Δraw) _print("broadcast +") - (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δun), xs)...) + (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...) end end Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) = @@ -167,9 +174,9 @@ Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) = (::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ) function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y) - broadcasted(-, x, y), Δ -> let Δun = unthunk(Δ) + broadcasted(-, x, y), Δraw -> let Δ = unthunk(Δraw) _print("broadcast -") - (NoTangent(), NoTangent(), unbroadcast(x, Δun), -unbroadcast(y, Δun)) + (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)) # Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array end end @@ -177,10 +184,10 @@ end using LinearAlgebra: dot function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y) # should this be vararg, or will laziness handle it? - broadcasted(*, x, y), Δ -> let Δun = unthunk(Δ) + broadcasted(*, x, y), Δraw -> let Δ = unthunk(Δraw) _print("broadcast *") - dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δun) : unbroadcast(x, Δun .* conj.(y)) - dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δun) : unbroadcast(y, Δun .* conj.(x)) + dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δ) : unbroadcast(x, Δ .* conj.(y)) + dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δ) : unbroadcast(y, Δ .* conj.(x)) # When x is an array but a smaller one, instead of dot you may be able to use mapreduce() # Will things like this work? Ref([1,2]) .* [1,2,3] (NoTangent(), NoTangent(), dx, dy) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 2f8f48fd..036c4710 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -49,6 +49,8 @@ end Base.getindex(o::OpticBundle, i::Int) = i == 1 ? o.x : i == 2 ? o.clos : throw(BoundsError(o, i)) +Base.lastindex(o::OpticBundle) = 2 + Base.iterate(o::OpticBundle) = (o.x, nothing) Base.iterate(o::OpticBundle, ::Nothing) = (o.clos, missing) Base.iterate(o::OpticBundle, ::Missing) = nothing diff --git a/test/runtests.jl b/test/runtests.jl index 76a54dbc..fa2d6085 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -231,7 +231,9 @@ exp_log(x) = exp(log(x)) @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 @test unthunk.(gradient(x -> sum(x ./ 4), [1,2,3])) == ([0.25, 0.25, 0.25],) -@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) +@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule +@test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule +@test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule @test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),) From 2ac91f0e2e27a65fe108d6ad60e94063fc837640 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 20 Jan 2022 01:21:19 -0500 Subject: [PATCH 7/9] more --- src/extra_rules.jl | 17 ++++----- src/stage1/broadcast.jl | 57 +++++++++++++---------------- test/runtests.jl | 79 ++++++++++++++++++++++------------------- 3 files changed, 76 insertions(+), 77 deletions(-) diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 526fa629..a371734d 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -150,6 +150,7 @@ end # @ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type) # now in CR 1.18 ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent() +ChainRulesCore.canonicalize(::NoTangent) = NoTangent() # # Skip AD'ing through the axis computation # function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted) @@ -200,12 +201,13 @@ function ChainRules.rrule(::typeof(map), ::typeof(+), A::AbstractVector, B::Abst map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ) end -function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N} - # We're leaving these in the eltype that the cotangent vector already has. - # There isn't really a good reason to believe we should convert to the - # original array type, so don't unless explicitly requested. - AT(x), Δ->(NoTangent(), Δ) -end +# https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/array.jl#L7 +# function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N} +# # We're leaving these in the eltype that the cotangent vector already has. +# # There isn't really a good reason to believe we should convert to the +# # original array type, so don't unless explicitly requested. +# AT(x), Δ->(NoTangent(), Δ) +# end # WARNING: Method definition rrule(Type{var"#s260"} where var"#s260"<:(Array{T, N} where N where T), UndefInitializer, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Base/array.jl:5 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:209. # function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...) @@ -254,10 +256,9 @@ function ChainRules.frule(_, ::Type{Vector{T}}, undef::UndefInitializer, dims::I Vector{T}(undef, dims...), zeros(T, dims...) end -@ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer) +# @ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer) CR#558 @ChainRules.non_differentiable Base.throw(err) @ChainRules.non_differentiable Core.Compiler.return_type(args...) -ChainRulesCore.canonicalize(::NoTangent) = NoTangent() # Disable thunking at higher order (TODO: These should go into ChainRulesCore) function ChainRulesCore.rrule(::Type{Thunk}, thnk) diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index 49333eac..078930b2 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -33,12 +33,12 @@ end using ChainRulesCore: derivatives_given_output -# _print(s) = nothing -_print(s) = printstyled(s, "\n"; color=:magenta) +_print(s) = nothing +# _print(s) = printstyled(s, "\n"; color=:magenta) # Broadcast over one element is just map function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} - _print("path 0") + _print("path 0, order $N") ∂⃖ₙ(map, f, a) end @@ -47,8 +47,8 @@ end function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N} T = Broadcast.combine_eltypes(f, args) TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...}) - if eltype(T) == Bool - # Trivial case: non-differentiable output + if T === Bool + # Trivial case: non-differentiable output, e.g. `x .> 0` _print("path 1") back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2) return f.(args...), back_1 @@ -160,16 +160,14 @@ end # For certain cheap operations we can easily allow fused broadcast: -(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = split_bc_plus(args...) -(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = split_bc_plus(arg) # ambiguity -function split_bc_plus(xs...) where {F} +(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = lazy_bc_plus(args...) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = lazy_bc_plus(arg) # ambiguity +function lazy_bc_plus(xs...) where {F} broadcasted(+, xs...), Δraw -> let Δ = unthunk(Δraw) _print("broadcast +") (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...) end end -Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) = - mapreduce(eltype, promote_type, bc.args) # needed to hit fast path (::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ) @@ -182,24 +180,22 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y) end using LinearAlgebra: dot +const Numeric{T<:Number} = Union{T, AbstractArray{T}} -function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y) # should this be vararg, or will laziness handle it? +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) broadcasted(*, x, y), Δraw -> let Δ = unthunk(Δraw) _print("broadcast *") dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δ) : unbroadcast(x, Δ .* conj.(y)) dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δ) : unbroadcast(y, Δ .* conj.(x)) # When x is an array but a smaller one, instead of dot you may be able to use mapreduce() - # Will things like this work? Ref([1,2]) .* [1,2,3] (NoTangent(), NoTangent(), dx, dy) end end -# Alternative to `x isa Number` etc above... but not quite right! -# (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y::Number) = rrule_via_ad(DiffractorRuleConfig(), *, x, y) function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x, ::Val{2}) _print("broadcast ^2") broadcasted(*, x, x), Δ -> begin - dx = unbroadcast(x, 2 .* Δ .* conj.(x)) + dx = unbroadcast(x, 2 .* unthunk(Δ) .* conj.(x)) (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent()) end end @@ -208,30 +204,25 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::type x^2, Δ -> (NoTangent(), NoTangent(), NoTangent(), 2 * Δ * conj(x), NoTangent()) end -# function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x, y) # not obvious whether this is better than automatic -# broadcasted(/, x, y), Δ -> let Δun = unthunk(Δ) -# _print("broadcast /") -# dx = unbroadcast(x, Δ ./ conj.(y)) -# dy = unbroadcast(y, .-Δ .* conj.(res ./ y)) -# (NoTangent(), NoTangent(), dx, dy) -# end -# end -function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x, y::Number) +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Numeric, y::Number) _print("simple /") z, back = ∂⃖{1}()(/, x, y) - z, Δ -> begin - _, dx, dy = back(Δ) - (NoTangent(), NoTangent(), dx, dy) # maybe there should be a funciton for this? Use for conj, identity too + z, dz -> begin + _, dx, dy = back(dz) + (NoTangent(), NoTangent(), dx, dy) end end +(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) = x, identity_pullback +(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x::Array) = x, identity_pullback # ambiguity +identity_pullback(Δ) = (NoTangent(), NoTangent(), Δ) + +(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) = x, identity_pullback +(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{Real}) = x, identity_pullback (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) = broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ))) -(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) = - x, Δ -> (NoTangent(), Δ) - -(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) = - x, Δ -> (NoTangent(), Δ) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array) = + broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ))) # All broadcasts use `unbroadcast` to reduce to correct shape: @@ -244,7 +235,7 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx) ProjectTo(x)(sum(dx; dims)) end end -unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::NoTangent) = NoTangent() +unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx))) function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} diff --git a/test/runtests.jl b/test/runtests.jl index fa2d6085..72ec68f3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -215,42 +215,49 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test delta45 ≈ 1.0 # Broadcasting -@test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output -@test gradient(x -> sum(sqrt.(atan.(x, x'))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 -@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) - -@test_broken gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback -exp_log(x) = exp(log(x)) -@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) -@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) -@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4) -@test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure - -@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays -@test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 -@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 - -@test unthunk.(gradient(x -> sum(x ./ 4), [1,2,3])) == ([0.25, 0.25, 0.25],) -@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule -@test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule -@test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule - -@test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output -@test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),) -@test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent()) -@test gradient(x -> sum(x .+ [1,2,3]), true) == (NoTangent(),) # Bool input -@test gradient(x -> sum(x ./ [1,2,3]), [true false]) == (NoTangent(),) -@test gradient(x -> sum(x .* [1,2,3]'), (true, false)) == (NoTangent(),) - -tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), [3,4,5]') -@test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0) -@test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] -@test tup_adj[2] isa Adjoint -@test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal - -@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3]) # path 0, MethodError: no method matching Diffractor.Jet(::Int64, ::Float64, ::Tuple{Float64, Float64}) -@test_broken gradient(x -> sum(gradient(x -> sum(x' .* x), x)[1]), [1,2,3]) == ([6,6,6],) # Control flow support not fully implemented yet for higher-order reverse mode -@test_broken gradient(x -> sum(gradient(x -> sum(x' ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] +@testset "broadcast" begin + @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output + @test gradient(x -> sum(sqrt.(atan.(x, x'))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 + @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) + + @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback + exp_log(x) = exp(log(x)) + @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) + @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) + @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4) + @test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure + + @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays + @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 + @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 + @test gradient(x -> sum(sum, (x,) .* x'), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path + + @test unthunk.(gradient(x -> sum(x ./ 4), [1,2,3])) == ([0.25, 0.25, 0.25],) + @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule + @test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule + @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule + + @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),) + @test gradient(x -> sum([1,2,3]' .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),) + @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),) + + @test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output + @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),) + @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent()) + @test gradient(x -> sum(x .+ [1,2,3]), true) == (NoTangent(),) # Bool input + @test gradient(x -> sum(x ./ [1,2,3]), [true false]) == (NoTangent(),) + @test_broken gradient(x -> sum(x .* [1,2,3]'), (true, false)) == (NoTangent(),) # Cannot `convert` an object of type NoTangent to an object of type ZeroTangent + + tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), [3,4,5]') + @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0) + @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] + @test tup_adj[2] isa Adjoint + @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal + + @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3]) # path 0, MethodError: no method matching Diffractor.Jet(::Int64, ::Float64, ::Tuple{Float64, Float64}) + @test_broken gradient(x -> sum(gradient(x -> sum(x' .* x), x)[1]), [1,2,3]) == ([6,6,6],) # Control flow support not fully implemented yet for higher-order reverse mode + @test_broken gradient(x -> sum(gradient(x -> sum(x' ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] +end # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) #include("pinn.jl") From 84989e6430785edf74475a70999527ec2a4c80fb Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 20 Jan 2022 01:23:55 -0500 Subject: [PATCH 8/9] rm printing etc. --- src/stage1/broadcast.jl | 70 ----------------------------------------- 1 file changed, 70 deletions(-) diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index 078930b2..7cd2c24c 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -33,12 +33,8 @@ end using ChainRulesCore: derivatives_given_output -_print(s) = nothing -# _print(s) = printstyled(s, "\n"; color=:magenta) - # Broadcast over one element is just map function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} - _print("path 0, order $N") ∂⃖ₙ(map, f, a) end @@ -49,13 +45,11 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N} TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...}) if T === Bool # Trivial case: non-differentiable output, e.g. `x .> 0` - _print("path 1") back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2) return f.(args...), back_1 elseif T <: Number && isconcretetype(TΔ) # Fast path: just broadcast, and use x & y to find derivative. ys = f.(args...) - _print("path 2") function back_2_one(dys) # For f.(x) we do not need StructArrays / unzip at all delta = broadcast(unthunk(dys), ys, args...) do dy, y, a das = only(derivatives_given_output(y, f, a)) @@ -76,7 +70,6 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N} # Slow path: collect all the pullbacks & apply them later. # (Since broadcast makes no guarantee about order of calls, and un-fusing # can change the number of calls, this does not bother to try to reverse.) - _print("path 3") ys, backs = splitcast(∂⃖{1}(), f, args...) function back_3(dys) deltas = splitmap(backs, unthunk(dys)) do back, dy @@ -97,74 +90,16 @@ function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast. end # This uses "multimap"-like constructs: - using StructArrays splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...))) splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...)))) -#= -# This is how you could handle CuArrays, route them to unzip(map(...)) fallback path. -# Maybe 2nd derivatives too, to avoid writing a gradient for splitcast, rule for unzip is easy. - -function Diffractor.splitmap(f, args...) - if any(a -> a isa CuArray, args) - Diffractor._print("unzip splitmap") - unzip(map(f, args...)) - else - StructArrays.components(StructArray(Iterators.map(f, args...))) - end -end -function Diffractor.splitcast(f, args...) - if any(a -> a isa CuArray, args) - Diffractor._print("unzip splitcast") - unzip(broadcast(f, args...)) - else - StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...)))) - end -end - -gradient(x -> sum(log.(x) .+ x'), cu([1,2,3]))[1] -gradient(x -> sum(sqrt.(atan.(x, x'))), cu([1,2,3]))[1] - -=# - -function unzip(xs::AbstractArray) - x1 = first(xs) - x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples")) - N = length(x1) - unzip(xs, Val(N)) # like Zygote's unzip -end -@generated function unzip(xs, ::Val{N}) where {N} - each = [:(map($(Get(i)), xs)) for i in 1:N] - Expr(:tuple, each...) -end -unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy -@generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple} - each = if count(!Base.issingletontype, Ts.parameters) < 2 - # good case, no copy of data, some trivial arrays - [Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters] - else - [:(map($(Get(i)), xs)) for i in 1:length(fieldnames(Ts))] - end - Expr(:tuple, each...) -end - -struct Get{i} end -Get(i) = Get{Int(i)}() -(::Get{i})(x) where {i} = x[i] - -function ChainRulesCore.rrule(::typeof(unzip), xs::AbstractArray) - rezip(dy) = (NoTangent(), tuple.(unthunk(dy)...)) - return unzip(xs), rezip -end - # For certain cheap operations we can easily allow fused broadcast: (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = lazy_bc_plus(args...) (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = lazy_bc_plus(arg) # ambiguity function lazy_bc_plus(xs...) where {F} broadcasted(+, xs...), Δraw -> let Δ = unthunk(Δraw) - _print("broadcast +") (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...) end end @@ -173,7 +108,6 @@ end function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y) broadcasted(-, x, y), Δraw -> let Δ = unthunk(Δraw) - _print("broadcast -") (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)) # Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array end @@ -184,7 +118,6 @@ const Numeric{T<:Number} = Union{T, AbstractArray{T}} function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) broadcasted(*, x, y), Δraw -> let Δ = unthunk(Δraw) - _print("broadcast *") dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δ) : unbroadcast(x, Δ .* conj.(y)) dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δ) : unbroadcast(y, Δ .* conj.(x)) # When x is an array but a smaller one, instead of dot you may be able to use mapreduce() @@ -193,19 +126,16 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeri end function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x, ::Val{2}) - _print("broadcast ^2") broadcasted(*, x, x), Δ -> begin dx = unbroadcast(x, 2 .* unthunk(Δ) .* conj.(x)) (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent()) end end function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) - _print("simple ^2") x^2, Δ -> (NoTangent(), NoTangent(), NoTangent(), 2 * Δ * conj(x), NoTangent()) end function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Numeric, y::Number) - _print("simple /") z, back = ∂⃖{1}()(/, x, y) z, dz -> begin _, dx, dy = back(dz) From c9716657199cf69bb8f2026b501a0e70cff080f0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 21 Jan 2022 17:24:33 -0500 Subject: [PATCH 9/9] fixup --- src/extra_rules.jl | 43 ++++++------------ src/stage1/broadcast.jl | 99 ++++++++++++++++++++++++----------------- test/runtests.jl | 60 ++++++++++++++++--------- 3 files changed, 108 insertions(+), 94 deletions(-) diff --git a/src/extra_rules.jl b/src/extra_rules.jl index a371734d..164ded8d 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -138,27 +138,17 @@ struct NonDiffOdd{N, O, P}; end (::NonDiffEven{N, O, O})(Δ...) where {N, O} = error() # WARNING: Method definition rrule(typeof(Core.apply_type), Any, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Core/core.jl:10 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:140. -# @Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...) -# Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}() -# end +@Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...) + Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}() +end function ChainRulesCore.rrule(::typeof(Core.tuple), args...) Core.tuple(args...), Δ->Core.tuple(NoTangent(), Δ...) end -# TODO: What to do about these integer rules -# @ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type) # now in CR 1.18 - ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent() ChainRulesCore.canonicalize(::NoTangent) = NoTangent() -# # Skip AD'ing through the axis computation -# function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted) -# return Base.Broadcast.instantiate(bc), Δ->begin -# Core.tuple(NoTangent(), Δ) -# end -# end - using StaticArrays @@ -201,22 +191,6 @@ function ChainRules.rrule(::typeof(map), ::typeof(+), A::AbstractVector, B::Abst map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ) end -# https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/array.jl#L7 -# function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N} -# # We're leaving these in the eltype that the cotangent vector already has. -# # There isn't really a good reason to believe we should convert to the -# # original array type, so don't unless explicitly requested. -# AT(x), Δ->(NoTangent(), Δ) -# end - -# WARNING: Method definition rrule(Type{var"#s260"} where var"#s260"<:(Array{T, N} where N where T), UndefInitializer, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Base/array.jl:5 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:209. -# function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...) -# # We're leaving these in the eltype that the cotangent vector already has. -# # There isn't really a good reason to believe we should convert to the -# # original array type, so don't unless explicitly requested. -# AT(undef, args...), Δ->(NoTangent(), NoTangent(), ntuple(_->NoTangent(), length(args))...) -# end - function unzip_tuple(t::Tuple) map(x->x[1], t), map(x->x[2], t) end @@ -256,7 +230,6 @@ function ChainRules.frule(_, ::Type{Vector{T}}, undef::UndefInitializer, dims::I Vector{T}(undef, dims...), zeros(T, dims...) end -# @ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer) CR#558 @ChainRules.non_differentiable Base.throw(err) @ChainRules.non_differentiable Core.Compiler.return_type(args...) @@ -273,3 +246,13 @@ end # ERROR: ArgumentError: Tangent for the primal Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}} should be backed by a AbstractDict type, not by NamedTuple{(:data,), Tuple{ChainRulesCore.ZeroTangent}}. ChainRulesCore._backing_error(::Type{<:Base.Pairs{Symbol}}, ::Type{<:NamedTuple}, _) = nothing # solves that! +# Rather than have a rule for broadcasted 3-arg *, just send it to the efficient path: +ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number) = ((y*z, x*z, x*y),) +function ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number, w::Number) + xy = x*y + zw = z*w + ((y*zw, x*zw, xy*w, xy*z),) +end + +# Fixes @test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;] +(project::ProjectTo{<:AbstractArray})(th::InplaceableThunk) = project(unthunk(th)) diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index 7cd2c24c..78964184 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -34,12 +34,14 @@ end using ChainRulesCore: derivatives_given_output # Broadcast over one element is just map -function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} - ∂⃖ₙ(map, f, a) -end +# function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} +# ∂⃖ₙ(map, f, a) +# end + +(::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ) -(::∂⃖{1})(::typeof(broadcasted), f, args...) = split_bc_rule(f, args...) -(::∂⃖{1})(::typeof(broadcasted), f, arg::Array) = split_bc_rule(f, arg) # ambiguity +(::∂⃖{1})(::typeof(broadcasted), f::F, args...) where {F} = split_bc_rule(f, args...) +# (::∂⃖{1})(::typeof(broadcasted), f::F, arg::Array) where {F} = split_bc_rule(f, arg) # ambiguity function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N} T = Broadcast.combine_eltypes(f, args) TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...}) @@ -48,17 +50,17 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N} back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2) return f.(args...), back_1 elseif T <: Number && isconcretetype(TΔ) - # Fast path: just broadcast, and use x & y to find derivative. + # Fast path: just broadcast, and use arguments & result to find derivatives. ys = f.(args...) function back_2_one(dys) # For f.(x) we do not need StructArrays / unzip at all delta = broadcast(unthunk(dys), ys, args...) do dy, y, a das = only(derivatives_given_output(y, f, a)) - dy * conj(only(das)) + dy * conj(only(das)) # possibly this * should be made nan-safe. end (NoTangent(), NoTangent(), unbroadcast(only(args), delta)) end function back_2_many(dys) - deltas = splitcast(unthunk(dys), ys, args...) do dy, y, as... + deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as... das = only(derivatives_given_output(y, f, as...)) map(da -> dy * conj(da), das) end @@ -70,62 +72,76 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N} # Slow path: collect all the pullbacks & apply them later. # (Since broadcast makes no guarantee about order of calls, and un-fusing # can change the number of calls, this does not bother to try to reverse.) - ys, backs = splitcast(∂⃖{1}(), f, args...) + ys3, backs = tuplecast(∂⃖{1}(), f, args...) function back_3(dys) - deltas = splitmap(backs, unthunk(dys)) do back, dy + deltas = tuplecast(backs, unthunk(dys)) do back, dy # could be map, sizes match map(unthunk, back(dy)) end - dargs = map(unbroadcast, args, Base.tail(deltas)) # no real need to close over args here + dargs = map(unbroadcast, args, Base.tail(deltas)) (NoTangent(), sum(first(deltas)), dargs...) end back_3(::AbstractZero) = (NoTangent(), map(Returns(ZeroTangent()), args)...) - return ys, back_3 + return ys3, back_3 end end +# Don't run broadcasting on scalars +function split_bc_rule(f::F, args::Number...) where {F} + z, back = ∂⃖{1}()(f, args...) + z, dz -> (NoTangent(), back(dz)...) +end + +split_bc_rule(::typeof(identity), x) = x, Δ -> (NoTangent(), NoTangent(), Δ) +split_bc_rule(::typeof(identity), x::Number) = x, Δ -> (NoTangent(), NoTangent(), Δ) + # Skip AD'ing through the axis computation function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted) uninstantiate(Δ) = Core.tuple(NoTangent(), Δ) return Base.Broadcast.instantiate(bc), uninstantiate end -# This uses "multimap"-like constructs: using StructArrays -splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...))) -splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...)))) + +function tuplecast(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if isconcretetype(T) + T <: Tuple || throw(ArgumentError("tuplecast(f, args) only works on functions returning a tuple.")) + end + bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...)) + StructArrays.components(StructArray(bc)) +end # For certain cheap operations we can easily allow fused broadcast: +const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted} -(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = lazy_bc_plus(args...) -(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = lazy_bc_plus(arg) # ambiguity +(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::NumericOrBroadcast...) = lazy_bc_plus(args...) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::Number) = split_bc_rule(+, args...) function lazy_bc_plus(xs...) where {F} broadcasted(+, xs...), Δraw -> let Δ = unthunk(Δraw) (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...) end end -(::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ) - -function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = split_bc_rule(-, x, y) +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast) broadcasted(-, x, y), Δraw -> let Δ = unthunk(Δraw) (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)) - # Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array end end using LinearAlgebra: dot -const Numeric{T<:Number} = Union{T, AbstractArray{T}} -function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = split_bc_rule(*, x, y) +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast) broadcasted(*, x, y), Δraw -> let Δ = unthunk(Δraw) - dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δ) : unbroadcast(x, Δ .* conj.(y)) - dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δ) : unbroadcast(y, Δ .* conj.(x)) - # When x is an array but a smaller one, instead of dot you may be able to use mapreduce() - (NoTangent(), NoTangent(), dx, dy) + (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ)) end end +_back_star(x, y, Δ) = unbroadcast(x, Δ .* conj.(y)) +_back_star(x::Number, y, Δ) = dot(y, Δ) +_back_star(x::Bool, y, Δ) = NoTangent() -function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x, ::Val{2}) +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2}) broadcasted(*, x, x), Δ -> begin dx = unbroadcast(x, 2 .* unthunk(Δ) .* conj.(x)) (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent()) @@ -135,41 +151,40 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::type x^2, Δ -> (NoTangent(), NoTangent(), NoTangent(), 2 * Δ * conj(x), NoTangent()) end -function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Numeric, y::Number) - z, back = ∂⃖{1}()(/, x, y) - z, dz -> begin - _, dx, dy = back(dz) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = split_bc_rule(/, x, y) +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number) + z = broadcast(/, x, y) + z, Δth -> let Δ = unthunk(Δth) + dx = unbroadcast(x, Δ ./ conj.(y)) + dy = -dot(z, Δ) / (conj(y)) # the reason to be eager is to allow dot here (NoTangent(), NoTangent(), dx, dy) end end -(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) = x, identity_pullback -(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x::Array) = x, identity_pullback # ambiguity -identity_pullback(Δ) = (NoTangent(), NoTangent(), Δ) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) = split_bc_rule(identity, x) +# (::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x::Array) = split_bc_rule(identity, x) # ambiguity -(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) = x, identity_pullback -(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{Real}) = x, identity_pullback +(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) = split_bc_rule(identity, x) +# (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{Real}) = split_bc_rule(identity, x) # ambiguity (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) = broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ))) (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array) = broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ))) -# All broadcasts use `unbroadcast` to reduce to correct shape: - +# Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape: function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx) N = ndims(dx) if length(x) == length(dx) ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors else - dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # awful hack to get type-stable `dims` - ProjectTo(x)(sum(dx; dims)) + dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # hack to get type-stable `dims` + ProjectTo(x)(sum(dx; dims)) # ideally this sum might be thunked? end end unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx))) function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} - _print("unbroadcast tuple") val = if length(x) == length(dx) dx else diff --git a/test/runtests.jl b/test/runtests.jl index 72ec68f3..d752fc11 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Diffractor -using Diffractor: var"'", ∂⃖, DiffractorRuleConfig +using Diffractor: ∂⃖, DiffractorRuleConfig using ChainRules using ChainRulesCore using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad @@ -8,6 +8,11 @@ using LinearAlgebra using Test +const fwd = Diffractor.PrimeDerivativeFwd +const bwd = Diffractor.PrimeDerivativeBack + +@testset verbose=true "Unit tests" begin # from before broadcasting PR + # Unit tests function tup2(f) a, b = ∂⃖{2}()(f, 1) @@ -43,7 +48,7 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent() # Minimal 2-nd order forward smoke test @test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), - Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) + Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == cos(1.0) function simple_control_flow(b, x) if b @@ -86,7 +91,7 @@ isa_control_flow(::Type{T}, x) where {T} = isa(x, T) ? x : T(x) let var"'" = Diffractor.PrimeDerivativeBack # Integration tests @test @inferred(sin'(1.0)) == cos(1.0) - @test @inferred(sin''(1.0)) == -sin(1.0) + @test @inferred(sin''(1.0)) == -sin(1.0) broken=true @test sin'''(1.0) == -cos(1.0) @test sin''''(1.0) == sin(1.0) @test sin'''''(1.0) == cos(1.0) @@ -100,10 +105,15 @@ let var"'" = Diffractor.PrimeDerivativeBack # Higher order mixed mode tests complicated_2sin(x) = (x = map(sin, Diffractor.xfill(x, 2)); x[1] + x[2]) - @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) - @test @inferred(complicated_2sin''(1.0)) == 2sin''(1.0) - @test @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0) - @test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) + @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) broken=true # inference broken on Julia nightly without PR68 + @test @inferred(complicated_2sin''(1.0)) == 2sin''(1.0) broken=true + @test @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0) broken=true + @test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) broken=true + + @test complicated_2sin'(1.0) == 2sin'(1.0) + @test complicated_2sin''(1.0) == 2sin''(1.0) + @test complicated_2sin'''(1.0) == 2sin'''(1.0) + @test complicated_2sin''''(1.0) == 2sin''''(1.0) # Control flow cases @test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0) @@ -149,9 +159,6 @@ end # Regression tests @test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0] -const fwd = Diffractor.PrimeDerivativeFwd -const bwd = Diffractor.PrimeDerivativeBack - function f_broadcast(a) l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]] return sum(l) @@ -161,7 +168,7 @@ end # Make sure that there's no infinite recursion in kwarg calls g_kw(;x=1.0) = sin(x) f_kw(x) = g_kw(;x) -@test bwd(f_kw)(1.0) == bwd(sin)(1.0) +@test bwd(f_kw)(1.0) == bwd(sin)(1.0) broken=true # involves [3] +(a::Tangent{Core.Box, ...}, b::Tangent{Core.Box, ...} function f_crit_edge(a, b, c, x) # A function with two critical edges. This used to trigger an issue where @@ -214,10 +221,12 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test z45 ≈ 2.0 @test delta45 ≈ 1.0 -# Broadcasting -@testset "broadcast" begin + +end # @testset + +@testset verbose=true "broadcast" begin @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output - @test gradient(x -> sum(sqrt.(atan.(x, x'))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 + @test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback @@ -230,7 +239,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 - @test gradient(x -> sum(sum, (x,) .* x'), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path + @test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path @test unthunk.(gradient(x -> sum(x ./ 4), [1,2,3])) == ([0.25, 0.25, 0.25],) @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule @@ -238,7 +247,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),) - @test gradient(x -> sum([1,2,3]' .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),) + @test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),) @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),) @test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output @@ -246,17 +255,24 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent()) @test gradient(x -> sum(x .+ [1,2,3]), true) == (NoTangent(),) # Bool input @test gradient(x -> sum(x ./ [1,2,3]), [true false]) == (NoTangent(),) - @test_broken gradient(x -> sum(x .* [1,2,3]'), (true, false)) == (NoTangent(),) # Cannot `convert` an object of type NoTangent to an object of type ZeroTangent + @test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) == (NoTangent(),) - tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), [3,4,5]') + tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5])) @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0) @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] - @test tup_adj[2] isa Adjoint + @test tup_adj[2] isa Transpose @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal - @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3]) # path 0, MethodError: no method matching Diffractor.Jet(::Int64, ::Float64, ::Tuple{Float64, Float64}) - @test_broken gradient(x -> sum(gradient(x -> sum(x' .* x), x)[1]), [1,2,3]) == ([6,6,6],) # Control flow support not fully implemented yet for higher-order reverse mode - @test_broken gradient(x -> sum(gradient(x -> sum(x' ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] + @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3) # path 0, order 2, MethodError: no method matching Diffractor.Jet(::Int64, ::Float64, ::Tuple{Float64, Float64}) + @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3.0])[1] ≈ exp.(1:3) # path 0, order 2, MethodError: no method matching length(::ChainRulesCore.InplaceableThunk + @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # Control flow support not fully implemented yet for higher-order reverse mode + @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] + + @test (@inferred Diffractor.split_bc_rule(<, [1,2,3], [4 5]); true) # path 1, bool + @test (@inferred Diffractor.split_bc_rule(+, [1,2,3]); true) # path 2, derivatives_given_output + @test (@inferred Diffractor.split_bc_rule(+, [1,2,3], [4 5]); true) # path 2 vararg, tuplecast + @test (@inferred Diffractor.split_bc_rule(exp_log, [1,2,3]); true) # path 3 generic + end # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)