@@ -218,3 +218,93 @@ function ∇prod_one_zero!(dx, x, dy::Number=1) # Assumes exactly one x is zero
218
218
dx[i_zero] += p_rest * dy
219
219
return
220
220
end
221
+
222
+ # ####
223
+ # #### `cumprod`
224
+ # ####
225
+
226
+ function rrule (:: typeof (cumprod), x:: AbstractVector{<:Real} ; dims= 1 )
227
+ y = cumprod (x; dims= dims) # does nothing unless dims == 1
228
+ function cumprod_pullback_1 (dy)
229
+ dx_thunk = InplaceableThunk (
230
+ @thunk if dims == 1
231
+ ∇cumprod (x, dy, y)
232
+ else
233
+ dy
234
+ end
235
+ ,
236
+ dx -> if dims == 1
237
+ ∇cumprod! (dx, x, dy, y)
238
+ else
239
+ dx .+ = dy
240
+ end
241
+ )
242
+ return (NO_FIELDS, dx_thunk)
243
+ end
244
+ return y, cumprod_pullback_1
245
+ end
246
+
247
+ function rrule (:: typeof (cumprod), x:: AbstractArray{<:Real} ; dims)
248
+ y = cumprod (x; dims= dims)
249
+ @assert dims isa Integer
250
+ # vald = Val(dims)
251
+ function cumprod_pullback_2 (dy)
252
+ dx_thunk = InplaceableThunk (
253
+ @thunk if dims <= ndims (x)
254
+ ∇cumprod_dim (dims, x, dy, y)
255
+ else
256
+ dy
257
+ end
258
+ ,
259
+ dx -> if dims <= ndims (x)
260
+ ∇cumprod_dim! (dx, dims, x, dy, y)
261
+ else
262
+ dx .+ = dy
263
+ end
264
+ )
265
+ return (NO_FIELDS, dx_thunk)
266
+ end
267
+ return y, cumprod_pullback_2
268
+ end
269
+
270
+ function ∇cumprod_dim (dim:: Integer , x:: AbstractArray , dy= fill! (zero (x),1 ), y= cumprod (x; dims= dim))
271
+ T = promote_type (eltype (x), eltype (dy))
272
+ dx = fill! (similar (x, T, axes (x)), zero (T))
273
+ ∇cumprod_dim! (dx, dim, x, dy, y)
274
+ return dx
275
+ end
276
+
277
+ function ∇cumprod_dim! (dx:: AbstractArray , dim:: Integer , x:: AbstractArray , dy, y)
278
+ iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x)) # type instability!
279
+ for ind in Iterators. product (iters... )
280
+ @views ∇cumprod! (dx[ind... ], x[ind... ], dy[ind... ], y[ind... ])
281
+ end
282
+ return dx
283
+ end
284
+
285
+ function ∇cumprod (x:: AbstractVector , dy= one (x), y= cumprod (x))
286
+ T = promote_type (eltype (x), eltype (dy))
287
+ dx = fill! (similar (x, T, axes (x)), zero (T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
288
+ ∇cumprod! (dx, x, dy, y)
289
+ return dx
290
+ end
291
+
292
+ function ∇cumprod! (dx:: AbstractVector , x:: AbstractVector , dy, y)
293
+ lo, hi = firstindex (x), lastindex (x)
294
+ z = something (findfirst (iszero, x), hi+ 1 )
295
+ @inbounds for i in lo: z- 1
296
+ ixi = 1 / x[i]
297
+ for k in i: z- 1
298
+ dx[i] += y[k] * dy[k] * ixi
299
+ end
300
+ end
301
+ @inbounds if z != hi+ 1
302
+ yk = z== 1 ? one (eltype (y)) : y[z- 1 ] # will be prod(x[j] for j=1:k if j!=z)
303
+ dx[z] += yk * dy[z]
304
+ for k in (z+ 1 ): hi
305
+ yk *= x[k]
306
+ dx[z] += yk * dy[k]
307
+ end
308
+ end
309
+ return dx
310
+ end
0 commit comments