@@ -218,3 +218,96 @@ function ∇prod_one_zero!(dx, x, dy::Number=1) # Assumes exactly one x is zero
218
218
dx[i_zero] += p_rest * dy
219
219
return
220
220
end
221
+
222
+ # ####
223
+ # #### `cumprod`
224
+ # ####
225
+
226
+ function rrule (:: typeof (cumprod), x:: AbstractVector{<:Real} ; dims:: Integer = 1 )
227
+ y = cumprod (x; dims= dims) # does nothing unless dims == 1
228
+ project_x = ProjectTo (x)
229
+ function cumprod_pullback_1 (dy_raw)
230
+ dy = unthunk (dy_raw)
231
+ dx_thunk = InplaceableThunk (
232
+ dx -> if dims == 1
233
+ ∇cumprod! (dx, x, dy, y)
234
+ else
235
+ dx .+ = dy
236
+ end
237
+ ,
238
+ @thunk project_x (if dims == 1
239
+ ∇cumprod (x, dy, y)
240
+ else
241
+ dy
242
+ end )
243
+ )
244
+ return (NoTangent (), dx_thunk)
245
+ end
246
+ return y, cumprod_pullback_1
247
+ end
248
+
249
+ function rrule (:: typeof (cumprod), x:: AbstractArray{<:Real} ; dims:: Integer )
250
+ y = cumprod (x; dims= dims)
251
+ project_x = ProjectTo (x)
252
+ function cumprod_pullback_2 (dy_raw)
253
+ dy = unthunk (dy_raw)
254
+ dx_thunk = InplaceableThunk (
255
+ dx -> if dims <= ndims (x)
256
+ vald = Val (Int (dims))
257
+ ∇cumprod_dim! (dx, vald, x, dy, y)
258
+ else
259
+ dx .+ = dy
260
+ end
261
+ ,
262
+ @thunk project_x (if dims <= ndims (x)
263
+ vald = Val (Int (dims))
264
+ ∇cumprod_dim (vald, x, dy, y)
265
+ else
266
+ dy
267
+ end )
268
+ )
269
+ return (NoTangent (), dx_thunk)
270
+ end
271
+ return y, cumprod_pullback_2
272
+ end
273
+
274
+ function ∇cumprod_dim (vald:: Val{dim} , x:: AbstractArray , dy= fill! (zero (x),1 ), y= cumprod (x; dims= dim)) where {dim}
275
+ T = promote_type (eltype (x), eltype (dy))
276
+ dx = fill! (similar (x, T, axes (x)), zero (T))
277
+ ∇cumprod_dim! (dx, vald, x, dy, y)
278
+ return dx
279
+ end
280
+
281
+ @inline function ∇cumprod_dim! (dx:: AbstractArray , :: Val{dim} , x:: AbstractArray , dy, y) where {dim}
282
+ iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x))
283
+ for ind in Iterators. product (iters... )
284
+ @views ∇cumprod! (dx[ind... ], x[ind... ], dy[ind... ], y[ind... ])
285
+ end
286
+ return dx
287
+ end
288
+
289
+ function ∇cumprod (x:: AbstractVector , dy= one (x), y= cumprod (x))
290
+ T = promote_type (eltype (x), eltype (dy)) # really needs to allow dy * y / x
291
+ dx = fill! (similar (x, T, axes (x)), zero (T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
292
+ ∇cumprod! (dx, x, dy, y)
293
+ return dx
294
+ end
295
+
296
+ @inline function ∇cumprod! (dx:: AbstractVector , x:: AbstractVector , dy, y)
297
+ lo, hi = firstindex (x), lastindex (x)
298
+ z = something (findfirst (iszero, x), hi+ 1 )
299
+ acc = zero (eltype (dy))
300
+ @inbounds for k in z- 1 : - 1 : lo
301
+ acc += y[k] * dy[k]
302
+ dx[k] += acc / x[k]
303
+ end
304
+ @inbounds if z != hi+ 1
305
+ yk = z== 1 ? one (eltype (y)) : y[z- 1 ] # will be prod(x[j] for j=1:k if j!=z)
306
+ dx[z] += yk * dy[z]
307
+ for k in (z+ 1 ): hi
308
+ yk *= x[k]
309
+ dx[z] += yk * dy[k]
310
+ end
311
+ end
312
+ return dx
313
+ end
0 commit comments