@@ -34,6 +34,90 @@ function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
34
34
∂⃖ₙ (map, f, a)
35
35
end
36
36
37
+ using ChainRulesCore: derivatives_given_output
38
+
39
# First-order reverse pass over a broadcast of `f` on any number of arguments:
# forwards directly to `split_bc_rule`.
function (::∂⃖{1})(::typeof(broadcasted), f, args...)
    return split_bc_rule(f, args...)
end
40
# Single-`Array` variant of the first-order broadcast rule. It forwards to
# `split_bc_rule` as well and exists only to resolve a method ambiguity.
function (::∂⃖{1})(::typeof(broadcasted), f, arg::Array)
    return split_bc_rule(f, arg)
end
41
# Build `(ys, pullback)` for the broadcast `f.(args...)`, choosing one of three paths:
#
# * Trivial: the combined element type is `Bool` and `f` is a singleton type, so the
#   pullback returns `ZeroTangent()` in every slot.
# * Fast: every argument is `Numeric` and the return type of `derivatives_given_output`
#   infers to a concrete type, so one forward broadcast suffices and the partial
#   derivatives are computed from the inputs and outputs.
# * Slow: broadcast `rrule_via_ad`, keep every element's pullback, and apply them later.
#
# Each branch binds its forward result to its own variable (`ys_fast` / `ys_slow`).
# A variable that is captured by a closure and assigned in more than one place gets
# boxed, which would make the fast-path pullback type-unstable.
function split_bc_rule(f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    if T === Bool && Base.issingletontype(F)
        # Trivial case: `length(args) + 2` zero tangents, the same arity as the
        # tuples returned by the other two pullbacks.
        back_1(_) = ntuple(Returns(ZeroTangent()), length(args) + 2)
        return f.(args...), back_1
    elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type(
        derivatives_given_output, Tuple{T, F, map(eltype, args)...}))
        # Fast path: just broadcast, and use x & y to find the derivative.
        ys_fast = f.(args...)
        function back_2(dys)
            deltas = splitcast(unthunk(dys), ys_fast, args...) do dy, y, as...
                # `only` unwraps the single entry returned by `derivatives_given_output`;
                # `das` then holds one partial derivative per argument.
                das = only(derivatives_given_output(y, f, as...))
                map(da -> dy * conj(da), das)
            end
            dargs = map(unbroadcast, args, deltas)
            return (NoTangent(), NoTangent(), dargs...)
        end
        return ys_fast, back_2
    else
        # Slow path: collect all the pullbacks & apply them later.
        ys_slow, backs = splitcast(rrule_via_ad, DiffractorRuleConfig(), f, args...)
        function back_3(dys)
            deltas = splitmap(backs, unthunk(dys)) do back, dy
                map(unthunk, back(dy))
            end
            # The first component of `deltas` is summed into the tangent slot for `f`;
            # the remaining components line up with `args`.
            dargs = map(unbroadcast, args, Base.tail(deltas)) # no real need to close over args here
            return (NoTangent(), sum(first(deltas)), dargs...)
        end
        return ys_slow, back_3
    end
end
75
+
76
+ using StructArrays
77
# Lazily map `f` over `args...`, collect the results into a `StructArray`, and return
# that array's components.
# Warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
function splitmap(f, args...)
    mapped = Iterators.map(f, args...)
    return StructArrays.components(StructArray(mapped))
end
78
# Broadcast `f` over `args...` lazily, collect the instantiated broadcast into a
# `StructArray`, and return that array's components.
function splitcast(f, args...)
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    return StructArrays.components(StructArray(bc))
end
79
+
80
# When the broadcast argument is a function, reduce its gradient with `accum_sum`.
function unbroadcast(f::Function, x̄)
    return accum_sum(x̄)
end

# A `Val` argument carries no tangent, whatever gradient arrives.
function unbroadcast(::Val, _)
    return NoTangent()
end
82
# Reducing an array whose elements are all `NoTangent` yields `NoTangent()`;
# `dims` is accepted but unused.
function accum_sum(xs::AbstractArray{<:NoTangent}; dims=:)
    return NoTangent()
end
83
+
84
+ #=
85
+
86
+ julia> xs = randn(10_000);
87
+ julia> @btime Zygote.gradient(x -> sum(abs2, x), $xs)
88
+ 4.744 μs (2 allocations: 78.17 KiB)
89
+ julia> @btime Diffractor.unthunk.(gradient(x -> sum(abs2, x), $xs));
90
+ 3.307 μs (2 allocations: 78.17 KiB)
91
+
92
+ # Simple function
93
+
94
+ julia> @btime Zygote.gradient(x -> sum(abs2, exp.(x)), $xs);
95
+ 72.541 μs (29 allocations: 391.47 KiB) # with dual numbers -- like 4 copies
96
+
97
+ julia> @btime gradient(x -> sum(abs2, exp.(x)), $xs);
98
+ 45.875 μs (36 allocations: 235.47 KiB) # fast path -- one copy forward, one back
99
+ 44.042 μs (32 allocations: 313.48 KiB) # slow path -- 3 copies, extra is closure?
100
+ 61.167 μs (12 allocations: 703.41 KiB) # with `map` rule as before -- worse
101
+
102
+ # Composed function, Zygote struggles
103
+
104
+ julia> @btime Zygote.gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
105
+ 97.167 μs (29 allocations: 391.61 KiB) # with dual numbers (Zygote master)
106
+ 93.238 ms (849567 allocations: 19.22 MiB) # without, thus Zygote.pullback
107
+
108
+ julia> @btime gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
109
+ 55.290 ms (830060 allocations: 49.75 MiB) # slow path
110
+ 14.747 ms (240043 allocations: 7.25 MiB) # with `map` rule as before -- better!
111
+
112
+ # Compare unfused
113
+
114
+ julia> @btime gradient(x -> sum(abs2, identity.(cbrt.(x))), $xs);
115
+ 69.458 μs (50 allocations: 392.09 KiB) # fast path -- two copies forward, two back
116
+ 75.041 μs (46 allocations: 470.11 KiB) # slow path -- 5 copies
117
+ 135.541 μs (27 allocations: 1.30 MiB) # with `map` rule as before -- worse
118
+
119
+ =#
120
+
37
121
# The below is from Zygote. TODO: Do we want to do something better here?
38
122
39
123
# A `nothing` gradient reduces to `NoTangent()`; `dims` is accepted but unused.
function accum_sum(xs::Nothing; dims=:)
    return NoTangent()
end
@@ -70,4 +154,4 @@ ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric)
70
154
Δ -> let Δ= unthunk (Δ); (NoTangent (), NoTangent (), unbroadcast (x, Δ), - unbroadcast (y, Δ)); end
71
155
72
156
# Reverse rule for the broadcasted product `x .* y` of two `Numeric` arguments.
# The pullback multiplies the incoming cotangent by the conjugate of the other factor
# and passes each result through `unbroadcast` with the corresponding argument.
function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric)
    function times_pullback(z̄)
        Δ = unthunk(z̄)
        x̄ = unbroadcast(x, Δ .* conj.(y))
        ȳ = unbroadcast(y, Δ .* conj.(x))
        return (NoTangent(), NoTangent(), x̄, ȳ)
    end
    return x .* y, times_pullback
end
0 commit comments