Skip to content

Commit ebd1700

Browse files
committed
update, rm comments
1 parent b484a14 commit ebd1700

File tree

4 files changed

+94
-126
lines changed

4 files changed

+94
-126
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1010
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
11+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1213
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1314
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1415

1516
[compat]
16-
ChainRules = "1.5"
17-
ChainRulesCore = "1.4"
17+
ChainRules = "1.17"
18+
ChainRulesCore = "1.11"
1819
Combinatorics = "1"
1920
StaticArrays = "1"
2021
StatsBase = "0.33"

src/extra_rules.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,10 @@ struct NonDiffOdd{N, O, P}; end
137137
# This should not happen
138138
(::NonDiffEven{N, O, O})(Δ...) where {N, O} = error()
139139

140-
@Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...)
141-
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}()
142-
end
140+
# WARNING: Method definition rrule(typeof(Core.apply_type), Any, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Core/core.jl:10 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:140.
141+
# @Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...)
142+
# Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}()
143+
# end
143144

144145
function ChainRulesCore.rrule(::typeof(Core.tuple), args...)
145146
Core.tuple(args...), Δ->Core.tuple(NoTangent(), Δ...)
@@ -206,12 +207,13 @@ function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {
206207
AT(x), Δ->(NoTangent(), Δ)
207208
end
208209

209-
function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...)
210-
# We're leaving these in the eltype that the cotangent vector already has.
211-
# There isn't really a good reason to believe we should convert to the
212-
# original array type, so don't unless explicitly requested.
213-
AT(undef, args...), Δ->(NoTangent(), NoTangent(), ntuple(_->NoTangent(), length(args))...)
214-
end
210+
# WARNING: Method definition rrule(Type{var"#s260"} where var"#s260"<:(Array{T, N} where N where T), UndefInitializer, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Base/array.jl:5 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:209.
211+
# function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...)
212+
# # We're leaving these in the eltype that the cotangent vector already has.
213+
# # There isn't really a good reason to believe we should convert to the
214+
# # original array type, so don't unless explicitly requested.
215+
# AT(undef, args...), Δ->(NoTangent(), NoTangent(), ntuple(_->NoTangent(), length(args))...)
216+
# end
215217

216218
function unzip_tuple(t::Tuple)
217219
map(x->x[1], t), map(x->x[2], t)

src/stage1/broadcast.jl

Lines changed: 53 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
2929
return r
3030
end
3131

32+
_print(s) = nothing
33+
# _print(s) = printstyled(s, "\n"; color=:magenta)
34+
3235
# Broadcast over one element is just map
3336
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
37+
_print("path 0")
3438
∂⃖ₙ(map, f, a)
3539
end
3640

@@ -40,16 +44,16 @@ using ChainRulesCore: derivatives_given_output
4044
(::∂⃖{1})(::typeof(broadcasted), f, arg::Array) = split_bc_rule(f, arg) # ambiguity
4145
function split_bc_rule(f::F, args...) where {F}
4246
T = Broadcast.combine_eltypes(f, args)
43-
if T == Bool && Base.issingletontype(F)
47+
if T == Bool
4448
# Trivial case
49+
_print("path 1")
4550
back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2)
4651
return f.(args...), back_1
47-
# elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type(
4852
elseif isconcretetype(Core.Compiler._return_type(
4953
derivatives_given_output, Tuple{T, F, map(eltype, args)...}))
5054
# Fast path: just broadcast, and use x & y to find derivative.
5155
ys = f.(args...)
52-
# println("2")
56+
_print("path 2")
5357
function back_2(dys)
5458
deltas = splitcast(unthunk(dys), ys, args...) do dy, y, as...
5559
das = only(derivatives_given_output(y, f, as...))
@@ -61,7 +65,7 @@ function split_bc_rule(f::F, args...) where {F}
6165
return ys, back_2
6266
else
6367
# Slow path: collect all the pullbacks & apply them later.
64-
# println("3")
68+
_print("path 3")
6569
ys, backs = splitcast(rrule_via_ad, DiffractorRuleConfig(), f, args...)
6670
function back_3(dys)
6771
deltas = splitmap(backs, unthunk(dys)) do back, dy
@@ -78,108 +82,13 @@ using StructArrays
7882
splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...))) # warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
7983
splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
8084

81-
unbroadcast(f::Function, x̄) = accum_sum(x̄)
82-
unbroadcast(::Val, _) = NoTangent()
83-
unbroadcast(x::AbstractArray, x̄::NoTangent) = NoTangent()
84-
accum_sum(xs::AbstractArray{<:NoTangent}; dims = :) = NoTangent()
85-
86-
#=
87-
88-
julia> xs = randn(10_000);
89-
julia> @btime Zygote.gradient(x -> sum(abs2, x), $xs)
90-
4.744 μs (2 allocations: 78.17 KiB)
91-
julia> @btime Diffractor.unthunk.(gradient(x -> sum(abs2, x), $xs));
92-
3.307 μs (2 allocations: 78.17 KiB)
93-
94-
# Simple function
95-
96-
julia> @btime Zygote.gradient(x -> sum(abs2, exp.(x)), $xs);
97-
72.541 μs (29 allocations: 391.47 KiB) # with dual numbers -- like 4 copies
98-
99-
julia> @btime gradient(x -> sum(abs2, exp.(x)), $xs);
100-
45.875 μs (36 allocations: 235.47 KiB) # fast path -- one copy forward, one back
101-
44.042 μs (32 allocations: 313.48 KiB) # slow path -- 3 copies, extra is closure?
102-
61.167 μs (12 allocations: 703.41 KiB) # with `map` rule as before -- worse
103-
104-
105-
# Composed function, Zygote struggles
106-
107-
julia> @btime Zygote.gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
108-
97.167 μs (29 allocations: 391.61 KiB) # with dual numbers (Zygote master)
109-
93.238 ms (849567 allocations: 19.22 MiB) # without, thus Zygote.pullback
110-
111-
julia> @btime gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
112-
55.290 ms (830060 allocations: 49.75 MiB) # slow path
113-
14.747 ms (240043 allocations: 7.25 MiB) # with `map` rule as before -- better!
114-
115-
# Compare unfused
116-
117-
julia> @btime gradient(x -> sum(abs2, identity.(cbrt.(x))), $xs);
118-
69.458 μs (50 allocations: 392.09 KiB) # fast path -- two copies forward, two back
119-
75.041 μs (46 allocations: 470.11 KiB) # slow path -- 5 copies
120-
135.541 μs (27 allocations: 1.30 MiB) # with `map` rule as before -- worse
121-
122-
123-
# Lazy +,-,* for partial fusing
124-
125-
julia> @btime Zygote.gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs);
126-
81.250 μs (21 allocations: 625.47 KiB) # special rules + dual numbers, 4 more copies than best
127-
128-
julia> @btime gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs);
129-
57.166 μs (49 allocations: 470.22 KiB) # broadcast in *, -
130-
54.583 μs (46 allocations: 314.06 KiB) # broadcasted -- two less copies
131-
72.958 μs (26 allocations: 1016.38 KiB) # with `map` rule as before
132-
133-
julia> gradient((x,y) -> sum(abs2, exp.(2 .* x .+ y)), xs, (rand(10)'))
134-
ERROR: MethodError: no method matching size(::Base.Broadcast.Broadcasted # hmm
135-
136-
julia> @btime gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)'));
137-
7.127 ms (75 allocations: 22.97 MiB) # after
138-
12.956 ms (57 allocations: 76.37 MiB) # before
139-
140-
julia> @btime Zygote.gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)'));
141-
9.937 ms (48 allocations: 45.86 MiB)
142-
143-
=#
85+
# For certain cheap operations we can easily allow fused broadcast:
14486

145-
# The below is from Zygote: TODO: DO we want to do something better here?
146-
147-
accum_sum(xs::Nothing; dims = :) = NoTangent()
148-
accum_sum(xs::AbstractArray{Nothing}; dims = :) = NoTangent()
149-
accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims)
150-
accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims)
151-
accum_sum(xs::Number; dims = :) = xs
152-
153-
# https://github.com/FluxML/Zygote.jl/issues/594
154-
function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray, region)
155-
Base.reducedim_initarray(A, region, NoTangent(), Union{Nothing,eltype(A)})
156-
end
157-
158-
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
159-
160-
unbroadcast(x::Union{AbstractArray, Base.Broadcast.Broadcasted}, x̄) =
161-
size(x) == size(x̄) ? x̄ :
162-
length(x) == length(x̄) ? trim(x, x̄) :
163-
trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
164-
165-
unbroadcast(x::Number, x̄) = accum_sum(x̄)
166-
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
167-
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
168-
169-
unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent()
170-
171-
const Numeric = Union{Number, AbstractArray{<:Number, N} where N}
172-
173-
# function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...)
174-
# broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
175-
# end
176-
177-
# Replace Zygote-like fully split broadcasting with one fused over easy operations
17887
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = split_bc_plus(args...)
17988
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = split_bc_plus(arg) # ambiguity
18089
function split_bc_plus(xs...) where {F}
18190
broadcasted(+, xs...), Δ -> let Δun = unthunk(Δ)
182-
# println("+")
91+
_print("broadcast +")
18392
(NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δun), xs)...)
18493
end
18594
end
@@ -188,28 +97,61 @@ Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) =
18897

18998
(::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ)
19099

191-
# ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
192-
# Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
193-
194100
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y)
195101
broadcasted(-, x, y), Δ -> let Δun = unthunk(Δ)
196-
# println("-")
102+
_print("broadcast -")
197103
(NoTangent(), NoTangent(), unbroadcast(x, Δun), -unbroadcast(y, Δun))
198104
# Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array
199105
end
200106
end
201107

202-
# ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
203-
# z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end
204-
205108
using LinearAlgebra: dot
206109

207110
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y)
208111
broadcasted(*, x, y), Δ -> let Δun = unthunk(Δ)
209-
# println("*")
210-
dx = x isa Number ? dot(y, Δun) : unbroadcast(x, Δun .* conj.(y))
211-
dy = y isa Number ? dot(x, Δun) : unbroadcast(y, Δun .* conj.(x))
112+
_print("broadcast *")
113+
dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δun) : unbroadcast(x, Δun .* conj.(y))
114+
dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δun) : unbroadcast(y, Δun .* conj.(x))
212115
# When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
116+
# Will things like this work? Ref([1,2]) .* [1,2,3]
213117
(NoTangent(), NoTangent(), dx, dy)
214118
end
215119
end
120+
121+
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) =
122+
broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ)))
123+
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) =
124+
x, Δ -> (NoTangent(), Δ)
125+
126+
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
127+
N = ndims(dx)
128+
if length(x) == length(dx)
129+
ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors
130+
else
131+
# This is an awful hack to get type-stable `dims`
132+
dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N)
133+
ProjectTo(x)(sum(dx; dims))
134+
end
135+
end
136+
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::NoTangent) = NoTangent()
137+
138+
unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx))
139+
unbroadcast(f::Function, df) = ProjectTo(x)(sum(df))
140+
unbroadcast(x::Base.RefValue, dx) = ProjectTo(x)(Ref(sum(dx)))
141+
142+
unbroadcast(::Bool, dx) = NoTangent()
143+
unbroadcast(::AbstractArray{Bool}, dx) = NoTangent()
144+
unbroadcast(::AbstractArray{Bool}, ::NoTangent) = NoTangent() # ambiguity
145+
unbroadcast(::Val, dx) = NoTangent()
146+
# Maybe more non-diff types? Some fallback?
147+
148+
unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx)))
149+
function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
150+
_print("unbroadcast tuple")
151+
val = if length(x) == length(dx)
152+
dx
153+
else
154+
sum(dx; dims=2:ndims(dx))
155+
end
156+
ProjectTo(x)(NTuple{length(x)}(val)) # Tangent
157+
end

test/runtests.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,36 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
215215
@test delta45 ≈ 1.0
216216

217217
# Broadcasting
218-
@test gradient(x -> sum(x ./ x), [1,2,3]) == ([1,1,1],)
219-
@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) # derivatives_given_output
220-
@test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback
218+
@test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output
219+
@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)
220+
221+
@test_broken gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback
222+
exp_log(x) = exp(log(x))
223+
@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
221224
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
222225
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
223-
@test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool shortcut
226+
@test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure
227+
228+
@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays
229+
@test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
230+
@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
231+
232+
@test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output
233+
@test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),)
224234
@test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent())
235+
@test gradient(x -> sum(x .+ [1,2,3]), true) == (NoTangent(),) # Bool input
236+
@test gradient(x -> sum(x ./ [1,2,3]), [true false]) == (NoTangent(),)
237+
@test gradient(x -> sum(x .* [1,2,3]'), (true, false)) == (NoTangent(),)
238+
239+
tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), [3,4,5]')
240+
@test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0)
241+
@test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4]
242+
@test tup_adj[2] isa Adjoint
243+
@test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal
244+
245+
@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3]) # path 0, MethodError: no method matching Diffractor.Jet(::Int64, ::Float64, ::Tuple{Float64, Float64})
246+
@test_broken gradient(x -> sum(gradient(x -> sum(x' .* x), x)[1]), [1,2,3]) == ([6,6,6],) # Control flow support not fully implemented yet for higher-order reverse mode
247+
@test_broken gradient(x -> sum(gradient(x -> sum(x' ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516]
225248

226249
# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
227250
#include("pinn.jl")

0 commit comments

Comments
 (0)