@@ -72,30 +72,74 @@ function rrule(
72
72
end
73
73
74
74
function rrule (
75
- config:: RuleConfig{>:HasReverseMode} , :: typeof (sum), f, xs:: AbstractArray ; dims= :
76
- )
77
- fx_and_pullbacks = map (x-> rrule_via_ad (config, f, x), xs)
78
- y = sum (first, fx_and_pullbacks; dims= dims)
75
+ config:: RuleConfig{>:HasReverseMode} ,
76
+ :: typeof (sum),
77
+ f:: F ,
78
+ xs:: AbstractArray{T} ;
79
+ dims = :,
80
+ ) where {F,T}
81
+ project = ProjectTo (xs)
79
82
80
- pullbacks = last .(fx_and_pullbacks)
83
+ if _uses_input_only (f, T)
84
+ # Then we can compute the forward pass as usual, save nothing but `xs`:
85
+ function sum_pullback_f1 (dy)
86
+ dxs = broadcast (unthunk (dy), xs) do dyₖ, xᵢ
87
+ ∂yₖ∂xᵢ = only (only (derivatives_given_output (nothing , f, xᵢ)))
88
+ dyₖ * conj (∂yₖ∂xᵢ)
89
+ end
90
+ return (NoTangent (), NoTangent (), project (dxs))
91
+ end
92
+ return sum (f, xs; dims), sum_pullback_f1
93
+ end
81
94
82
- project = ProjectTo (xs)
95
+ # (There is an intermediate case, where `derivatives_given_output` needs to
96
+ # see `f.(xs)` but we don't need the pullbacks. Not implemented at present.)
97
+
98
+ # In the general case, we need to save all the pullbacks:
99
+ fx_and_pullbacks = map (xᵢ -> rrule_via_ad (config, f, xᵢ), xs)
100
+ y = sum (first, fx_and_pullbacks; dims)
101
+
102
+ function sum_pullback_f2 (dy)
103
+ # For arrays of arrays, we ought to protect the element against broadcasting:
104
+ broadcast_dy = dims isa Colon ? Ref (unthunk (dy)) : unthunk (dy)
105
+ if Base. issingletontype (F)
106
+ # Then at least `f` has no gradient.
107
+ # Broadcasting here gets the shape right with or without `dims` keyword.
108
+ dxs = broadcast (fx_and_pullbacks, broadcast_dy) do (_, pbᵢ), dyₖ
109
+ unthunk (last (pbᵢ (dyₖ)))
110
+ end
111
+ return (NoTangent (), NoTangent (), project (dxs))
83
112
84
- function sum_pullback (ȳ)
85
- call (f, x) = f (x)
86
- # if dims is :, then need only left-handed only broadcast
87
- broadcast_ȳ = dims isa Colon ? (ȳ,) : ȳ
88
- f̄_and_x̄s = call .(pullbacks, broadcast_ȳ)
89
- # no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
90
- f̄ = if fieldcount (typeof (f)) === 0 # Then don't need to worry about derivative wrt f
91
- NoTangent ()
92
113
else
93
- sum (first, f̄_and_x̄s)
114
+ # Most general case. If `f` were stateful, we would need to reverse the order
115
+ # of iteration here, but since this function makes no guarantee, even the primal
116
+ # result is then ill-defined.
117
+ df_and_dxs = broadcast (fx_and_pullbacks, broadcast_dy) do (_, pbᵢ), dyₖ
118
+ pbᵢ (dyₖ)
119
+ end
120
+ df = sum (first, df_and_dxs)
121
+ dxs = map (unthunk ∘ last, df_and_dxs)
122
+ return (NoTangent (), df, project (dxs))
94
123
end
95
- x̄s = map (unthunk ∘ last, f̄_and_x̄s) # project does not support receiving InplaceableThunks
96
- return NoTangent (), f̄, project (x̄s)
97
124
end
98
- return y, sum_pullback
125
+ return y, sum_pullback_f2
126
+ end
127
+
128
+ """
129
+ _uses_input_only(f, xT::Type)
130
+
131
+ Returns `true` if it can prove that `derivatives_given_output` will work using only the input
132
+ of the given type. Thus there is no need to store the output `y = f(x::xT)`, allowing us to take
133
+ a fast path in the `rrule` for `sum(f, xs)`.
134
+
135
+ Works by seeing if the result of `derivatives_given_output(nothing, f, x)` can be inferred.
136
+ The method of `derivatives_given_output` usually comes from `@scalar_rule`.
137
+ """
138
+ function _uses_input_only (f:: F , :: Type{xT} ) where {F,xT}
139
+ gT = Core. Compiler. _return_type (derivatives_given_output, Tuple{Nothing, F, xT})
140
+ # Here we must check `<: Number`, to avoid this, the one rule which can return the `nothing`:
141
+ # ChainRules.derivatives_given_output("anything", exp, 1) == (("anything",),)
142
+ return isconcretetype (gT) && gT <: Tuple{Tuple{Number}}
99
143
end
100
144
101
145
# https://github.com/JuliaDiff/ChainRules.jl/issues/522
@@ -228,6 +272,7 @@ function ∇prod_dims(vald::Val{dims}, x, dy, y=prod(x; dims=dims)) where {dims}
228
272
∇prod_dims! (dx, vald, x, dy, y)
229
273
return dx
230
274
end
275
+ ∇prod_dims (:: Val , x, dy:: AbstractZero , y= 0 ) = dy
231
276
232
277
function ∇prod_dims! (dx, :: Val{dims} , x, dy, y) where {dims}
233
278
iters = ntuple (d -> d in dims ? tuple (:) : axes (x,d), ndims (x)) # Without Val(dims) this is a serious type instability
@@ -244,6 +289,7 @@ function ∇prod(x, dy::Number=1, y::Number=prod(x))
244
289
∇prod! (dx, x, dy, y)
245
290
return dx
246
291
end
292
+ ∇prod (x, dy:: AbstractZero , y:: Number = 0 ) = dy
247
293
248
294
function ∇prod! (dx, x, dy:: Number = 1 , y:: Number = prod (x))
249
295
numzero = iszero (y) ? count (iszero, x) : 0
@@ -326,7 +372,8 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y
326
372
dx = fill! (similar (x, T, axes (x)), zero (T))
327
373
∇cumprod_dim! (dx, vald, x, dy, y)
328
374
return dx
329
- end
375
+ end
376
+ ∇cumprod_dim (vald:: Val , x:: AbstractArray , dy:: AbstractZero , y= 0 ) = dy
330
377
331
378
@inline function ∇cumprod_dim! (dx:: AbstractArray , :: Val{dim} , x:: AbstractArray , dy, y) where {dim}
332
379
iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x))
@@ -342,6 +389,7 @@ function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
342
389
∇cumprod! (dx, x, dy, y)
343
390
return dx
344
391
end
392
+ ∇cumprod (x:: AbstractVector , dy:: AbstractZero , y= 0 ) = dy
345
393
346
394
@inline function ∇cumprod! (dx:: AbstractVector , x:: AbstractVector , dy, y)
347
395
lo, hi = firstindex (x), lastindex (x)
0 commit comments