@@ -18,56 +18,121 @@ _print(args...) = nothing # println(join(args, " ")) #
18
18
# #### Split broadcasting
19
19
# ####
20
20
21
+ # For `z = g.(f.(xs))`, this finds `y = f.(x)` eagerly because the rules for either `f` or `g` may need it,
22
+ # and we don't know whether re-computing `y` is cheap.
23
+ # (We could check `f` first like `sum(f, x)` does, but checking whether `g` needs `y` is tricky.)
24
+
21
25
function rrule (cfg:: RCR , :: typeof (broadcasted), f:: F , args:: Vararg{Any,N} ) where {F,N}
22
- # = split_bc_rule(cfg, f, args...)
23
- # function split_bc_rule(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
24
26
T = Broadcast. combine_eltypes (f, args)
25
- TΔ = Core. Compiler. _return_type (derivatives_given_output, Tuple{T, F, map (eltype, args)... })
26
- if T === Bool
27
+ if T === Bool # TODO use nondifftype here
27
28
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
28
- _print (" split_bc_rule 1 " , f)
29
- back_1 (_) = ntuple (Returns (ZeroTangent ()), length (args)+ 2 )
30
- return f .(args... ), back_1
31
- elseif T <: Number && isconcretetype (TΔ)
32
- # 2: Fast path: just broadcast, and use arguments & result to find derivatives.
33
- _print (" split_bc_rule 2" , f, N)
34
- ys = f .(args... )
35
- function back_2_one (dys) # For f.(x) we do not need StructArrays / unzip at all
36
- delta = broadcast (unthunk (dys), ys, args... ) do dy, y, a
37
- das = only (derivatives_given_output (y, f, a))
38
- dy * conj (only (das)) # possibly this * should be made nan-safe.
39
- end
40
- (NoTangent (), NoTangent (), ProjectTo (only (args))(delta))
29
+ _print (" split_bc_trivial" , f)
30
+ bc_trivial_back (_) = (NoTangent (), NoTangent (), ntuple (Returns (ZeroTangent ()), length (args))... )
31
+ return f .(args... ), bc_trivial_back
32
+ elseif T <: Number && may_bc_derivatives (T, f, args... )
33
+ # 2: Fast path: use arguments & result to find derivatives.
34
+ return split_bc_derivatives (f, args... )
35
+ elseif T <: Number && may_bc_forwards (cfg, f, args... )
36
+ # 3: Future path: use `frule_via_ad`?
37
+ return split_bc_forwards (cfg, f, args... )
38
+ else
39
+ # 4: Slow path: collect all the pullbacks & apply them later.
40
+ return split_bc_pullbacks (cfg, f, args... )
41
+ end
42
+ end
43
+
44
+ # Path 2: This is roughly what `derivatives_given_output` is designed for, should be fast.
45
+
46
+ function may_bc_derivatives (:: Type{T} , f:: F , args:: Vararg{Any,N} ) where {T,F,N}
47
+ TΔ = Core. Compiler. _return_type (derivatives_given_output, Tuple{T, F, map (_eltype, args)... })
48
+ return isconcretetype (TΔ)
49
+ end
50
+
51
+ _eltype (x) = eltype (x) # ... but try harder to avoid `eltype(Broadcast.broadcasted(+, [1,2,3], 4.5)) == Any`:
52
+ _eltype (bc:: Broadcast.Broadcasted ) = Broadcast. combine_eltypes (bc. f, bc. args)
53
+
54
+ function split_bc_derivatives (f:: F , arg) where {F}
55
+ _print (" split_bc_derivative" , f)
56
+ ys = f .(arg)
57
+ function bc_one_back (dys) # For f.(x) we do not need StructArrays / unzip at all
58
+ delta = broadcast (unthunk (dys), ys, arg) do dy, y, a
59
+ das = only (derivatives_given_output (y, f, a))
60
+ dy * conj (only (das)) # possibly this * should be made nan-safe.
41
61
end
42
- back_2_one (z:: AbstractZero ) = (NoTangent (), NoTangent (), z)
43
- function back_2_many (dys)
44
- deltas = tuplecast (unthunk (dys), ys, args... ) do dy, y, as...
45
- das = only (derivatives_given_output (y, f, as... ))
46
- map (da -> dy * conj (da), das)
47
- end
48
- dargs = map (unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast?
49
- (NoTangent (), NoTangent (), dargs... )
62
+ return (NoTangent (), NoTangent (), ProjectTo (arg)(delta))
63
+ end
64
+ bc_one_back (z:: AbstractZero ) = (NoTangent (), NoTangent (), z)
65
+ return ys, bc_one_back
66
+ end
67
+ function split_bc_derivatives (f:: F , args:: Vararg{Any,N} ) where {F,N}
68
+ _print (" split_bc_derivatives" , f, N)
69
+ ys = f .(args... )
70
+ function bc_many_back (dys)
71
+ deltas = tuplecast (unthunk (dys), ys, args... ) do dy, y, as...
72
+ das = only (derivatives_given_output (y, f, as... ))
73
+ map (da -> dy * conj (da), das) # possibly this * should be made nan-safe.
50
74
end
51
- back_2_many (z:: AbstractZero ) = (NoTangent (), NoTangent (), map (Returns (z), args)... )
52
- return ys, N== 1 ? back_2_one : back_2_many
53
- else
54
- _print (" split_bc_rule 3" , f, N)
55
- # 3: Slow path: collect all the pullbacks & apply them later.
56
- # (Since broadcast makes no guarantee about order of calls, and un-fusing
57
- # can change the number of calls, don't bother to try to reverse the iteration.)
58
- ys3, backs = tuplecast (args... ) do a...
59
- rrule_via_ad (cfg, f, a... )
75
+ dargs = map (unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast?
76
+ return (NoTangent (), NoTangent (), dargs... )
77
+ end
78
+ bc_many_back (z:: AbstractZero ) = (NoTangent (), NoTangent (), map (Returns (z), args)... )
79
+ return ys, bc_many_back
80
+ end
81
+
82
+ # Path 3: Use forward mode, or an `frule` if one exists.
83
+ # To allow `args...` we need either chunked forward mode, with `adot::Tuple` perhaps:
84
+ # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/92
85
+ # https://github.com/JuliaDiff/Diffractor.jl/pull/54
86
+ # Or else we need to call the `f` multiple times, and maybe that's OK:
87
+ # We do know that `f` doesn't have parameters, so maybe it's pure enough,
88
+ # and split broadcasting may anyway change N^2 executions into N, e.g. `g.(v ./ f.(v'))`.
89
+ # We don't know `f` is cheap, but `split_bc_pullbacks` tends to be very slow.
90
+
91
+ function may_bc_forwards (cfg:: C , f:: F , args:: Vararg{Any,N} ) where {C,F,N}
92
+ Base. issingletontype (F) || return false
93
+ N== 1 || return false # Could weaken this to 1 differentiable
94
+ cfg isa RuleConfig{>: HasForwardsMode } && return true # allows frule_via_ad
95
+ TA = map (_eltype, args)
96
+ TF = Core. Compiler. _return_type (frule, Tuple{C, Tuple{NoTangent, TA... }, F, TA... })
97
+ return isconcretetype (TF) && TF <: Tuple
98
+ end
99
+
100
+ split_bc_forwards (cfg:: RuleConfig{>:HasForwardsMode} , f:: F , arg) where {F} = split_bc_inner (frule_via_ad, cfg, f, arg)
101
+ split_bc_forwards (cfg:: RuleConfig , f:: F , arg) where {F} = split_bc_inner (frule, cfg, f, arg)
102
+ function split_bc_inner (frule_fun:: R , cfg:: RuleConfig , f:: F , arg) where {R,F}
103
+ _print (" split_bc_forwards" , frule_fun, f)
104
+ ys, ydots = tuplecast (arg) do a
105
+ frule_fun (cfg, (NoTangent (), one (a)), f, a)
106
+ end
107
+ function back_forwards (dys)
108
+ delta = broadcast (ydots, unthunk (dys), arg) do ydot, dy, a
109
+ ProjectTo (a)(conj (ydot) * dy) # possibly this * should be made nan-safe.
60
110
end
61
- function back_3 (dys)
62
- deltas = tuplecast (backs, unthunk (dys)) do back, dy # could be map, sizes match
63
- map (unthunk, back (dy))
64
- end
65
- dargs = map (unbroadcast, args, Base. tail (deltas))
66
- (NoTangent (), ProjectTo (f)(sum (first (deltas))), dargs... )
111
+ return (NoTangent (), NoTangent (), ProjectTo (arg)(delta))
112
+ end
113
+ back_forwards (z:: AbstractZero ) = (NoTangent (), NoTangent (), z)
114
+ return ys, back_forwards
115
+ end
116
+
117
+ # Path 4: The most generic, save all the pullbacks. Can be 1000x slower.
118
+ # Since broadcast makes no guarantee about order of calls, and un-fusing
119
+ # can change the number of calls, don't bother to try to reverse the iteration.
120
+
121
+ function split_bc_pullbacks (cfg:: RCR , f:: F , args:: Vararg{Any,N} ) where {F,N}
122
+ _print (" split_bc_generic" , f, N)
123
+ ys3, backs = tuplecast (args... ) do a...
124
+ rrule_via_ad (cfg, f, a... )
125
+ end
126
+ function back_generic (dys)
127
+ deltas = tuplecast (backs, unthunk (dys)) do back, dy # (could be map, sizes match)
128
+ map (unthunk, back (dy))
67
129
end
68
- back_3 (z:: AbstractZero ) = (NoTangent (), NoTangent (), map (Returns (z), args)... )
69
- return ys3, back_3
130
+ dargs = map (unbroadcast, args, Base. tail (deltas))
131
+ df = ProjectTo (f)(sum (first (deltas)))
132
+ return (NoTangent (), df, dargs... )
70
133
end
134
+ back_generic (z:: AbstractZero ) = (NoTangent (), NoTangent (), map (Returns (z), args)... )
135
+ return ys3, back_generic
71
136
end
72
137
73
138
# Don't run broadcasting on scalars
@@ -158,8 +223,8 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast,
158
223
dz = unthunk (dz_raw)
159
224
dx = @thunk unbroadcast (x, dz ./ conj .(y))
160
225
# dy = @thunk -LinearAlgebra.dot(z, dz) / conj(y) # the reason to be eager is to allow dot here
161
- dy = @thunk - sum (Broadcast. instantiate (broadcasted (* , broadcasted (conj, z), dz))) / conj (y) # complete sum is fast?
162
- (NoTangent (), NoTangent (), dx, dy)
226
+ dy = @thunk - sum (Broadcast. instantiate (broadcasted (* , broadcasted (conj, z), dz))) / conj (y) # complete sum is fast
227
+ return (NoTangent (), NoTangent (), dx, dy)
163
228
end
164
229
return z, bc_divide_back
165
230
end
@@ -234,6 +299,13 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:
234
299
return broadcasted (imag, x), bc_imag_back_2
235
300
end
236
301
302
+ function rrule (:: RCR , :: typeof (broadcasted), :: typeof (complex), x:: NumericOrBroadcast )
303
+ _print (" bc complex" )
304
+ bc_complex_back (dz) = (NoTangent (), NoTangent (), @thunk (unbroadcast (x, unthunk (dz))))
305
+ return broadcasted (complex, x), bc_complex_back
306
+ end
307
+ rrule (:: RCR , :: typeof (broadcasted), :: typeof (complex), x:: Number ) = rrule (complex, x) |> _prepend_zero
308
+
237
309
# ####
238
310
# #### Shape fixing
239
311
# ####
@@ -259,7 +331,7 @@ function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
259
331
else
260
332
sum (dx; dims= 2 : ndims (dx))
261
333
end
262
- ProjectTo (x)(NTuple {length(x)} (val)) # Tangent
334
+ return ProjectTo (x)(NTuple {length(x)} (val)) # Tangent
263
335
end
264
336
265
337
unbroadcast (f:: Function , df) = sum (df)
0 commit comments