Skip to content

Commit 485c038

Browse files
committed
add split_bc_rule
1 parent bc22ad6 commit 485c038

File tree

3 files changed

+95
-2
lines changed

3 files changed

+95
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1414

1515
[compat]
1616
ChainRules = "1.5"
17-
ChainRulesCore = "1.2"
17+
ChainRulesCore = "1.4"
1818
Combinatorics = "1"
1919
StaticArrays = "1"
2020
StatsBase = "0.33"

src/stage1/broadcast.jl

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,90 @@ function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
3434
∂⃖ₙ(map, f, a)
3535
end
3636

37+
using ChainRulesCore: derivatives_given_output
38+
39+
(::∂⃖{1})(::typeof(broadcasted), f, args...) = split_bc_rule(f, args...)
40+
(::∂⃖{1})(::typeof(broadcasted), f, arg::Array) = split_bc_rule(f, arg) # ambiguity
41+
function split_bc_rule(f::F, args...) where {F}
42+
T = Broadcast.combine_eltypes(f, args)
43+
if T == Bool && Base.issingletontype(F)
44+
# Trivial case
45+
back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2)
46+
return f.(args...), back_1
47+
elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type(
48+
derivatives_given_output, Tuple{T, F, map(eltype, args)...}))
49+
# Fast path: just broadcast, and use x & y to find derivative.
50+
ys = f.(args...)
51+
# println("2")
52+
function back_2(dys)
53+
deltas = splitcast(unthunk(dys), ys, args...) do dy, y, as...
54+
das = only(derivatives_given_output(y, f, as...))
55+
map(da -> dy * conj(da), das)
56+
end
57+
dargs = map(unbroadcast, args, deltas)
58+
(NoTangent(), NoTangent(), dargs...)
59+
end
60+
return ys, back_2
61+
else
62+
# Slow path: collect all the pullbacks & apply them later.
63+
# println("3")
64+
ys, backs = splitcast(rrule_via_ad, DiffractorRuleConfig(), f, args...)
65+
function back_3(dys)
66+
deltas = splitmap(backs, unthunk(dys)) do back, dy
67+
map(unthunk, back(dy))
68+
end
69+
dargs = map(unbroadcast, args, Base.tail(deltas)) # no real need to close over args here
70+
(NoTangent(), sum(first(deltas)), dargs...)
71+
end
72+
return ys, back_3
73+
end
74+
end
75+
76+
using StructArrays
77+
splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...))) # warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
78+
splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
79+
80+
unbroadcast(f::Function, x̄) = accum_sum(x̄)
81+
unbroadcast(::Val, _) = NoTangent()
82+
accum_sum(xs::AbstractArray{<:NoTangent}; dims = :) = NoTangent()
83+
84+
#=
85+
86+
julia> xs = randn(10_000);
87+
julia> @btime Zygote.gradient(x -> sum(abs2, x), $xs)
88+
4.744 μs (2 allocations: 78.17 KiB)
89+
julia> @btime Diffractor.unthunk.(gradient(x -> sum(abs2, x), $xs));
90+
3.307 μs (2 allocations: 78.17 KiB)
91+
92+
# Simple function
93+
94+
julia> @btime Zygote.gradient(x -> sum(abs2, exp.(x)), $xs);
95+
72.541 μs (29 allocations: 391.47 KiB) # with dual numbers -- like 4 copies
96+
97+
julia> @btime gradient(x -> sum(abs2, exp.(x)), $xs);
98+
45.875 μs (36 allocations: 235.47 KiB) # fast path -- one copy forward, one back
99+
44.042 μs (32 allocations: 313.48 KiB) # slow path -- 3 copies, extra is closure?
100+
61.167 μs (12 allocations: 703.41 KiB) # with `map` rule as before -- worse
101+
102+
# Composed function, Zygote struggles
103+
104+
julia> @btime Zygote.gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
105+
97.167 μs (29 allocations: 391.61 KiB) # with dual numbers (Zygote master)
106+
93.238 ms (849567 allocations: 19.22 MiB) # without, thus Zygote.pullback
107+
108+
julia> @btime gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
109+
55.290 ms (830060 allocations: 49.75 MiB) # slow path
110+
14.747 ms (240043 allocations: 7.25 MiB) # with `map` rule as before -- better!
111+
112+
# Compare unfused
113+
114+
julia> @btime gradient(x -> sum(abs2, identity.(cbrt.(x))), $xs);
115+
69.458 μs (50 allocations: 392.09 KiB) # fast path -- two copies forward, two back
116+
75.041 μs (46 allocations: 470.11 KiB) # slow path -- 5 copies
117+
135.541 μs (27 allocations: 1.30 MiB) # with `map` rule as before -- worse
118+
119+
=#
120+
37121
# The below is from Zygote: TODO: DO we want to do something better here?
38122

39123
accum_sum(xs::Nothing; dims = :) = NoTangent()
@@ -70,4 +154,4 @@ ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric)
70154
Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
71155

72156
ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
73-
-> let=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end
157+
-> let=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,4 +214,13 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
214214
@test z45 2.0
215215
@test delta45 1.0
216216

217+
# Broadcasting
218+
@test gradient(x -> sum(x ./ x), [1,2,3]) == ([1,1,1],)
219+
@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) # derivatives_given_output
220+
@test gradient(x -> sum((explog).(x)), [1,2,3]) == ([1,1,1],) # stores pullback
221+
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
222+
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
223+
@test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool shortcut
224+
@test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent())
225+
217226
include("pinn.jl")

0 commit comments

Comments
 (0)