@@ -225,19 +225,20 @@ end
225
225
226
226
function rrule (:: typeof (cumprod), x:: AbstractVector{<:Real} ; dims:: Integer = 1 )
227
227
y = cumprod (x; dims= dims) # does nothing unless dims == 1
228
+ project_x = ProjectTo (x)
228
229
function cumprod_pullback_1 (dy)
229
230
dx_thunk = InplaceableThunk (
230
- @thunk if dims == 1
231
- ∇cumprod (x, dy, y)
232
- else
233
- dy
234
- end
235
- ,
236
231
dx -> if dims == 1
237
232
∇cumprod! (dx, x, dy, y)
238
233
else
239
234
dx .+ = dy
240
235
end
236
+ ,
237
+ @thunk project_x (if dims == 1
238
+ ∇cumprod (x, dy, y)
239
+ else
240
+ dy
241
+ end )
241
242
)
242
243
return (NO_FIELDS, dx_thunk)
243
244
end
@@ -246,21 +247,22 @@ end
246
247
247
248
function rrule (:: typeof (cumprod), x:: AbstractArray{<:Real} ; dims:: Integer )
248
249
y = cumprod (x; dims= dims)
250
+ project_x = ProjectTo (x)
249
251
function cumprod_pullback_2 (dy)
250
252
dx_thunk = InplaceableThunk (
251
- @thunk if dims <= ndims (x)
252
- vald = Val (Int (dims))
253
- ∇cumprod_dim (vald, x, dy, y)
254
- else
255
- dy
256
- end
257
- ,
258
253
dx -> if dims <= ndims (x)
259
254
vald = Val (Int (dims))
260
255
∇cumprod_dim! (dx, vald, x, dy, y)
261
256
else
262
257
dx .+ = dy
263
258
end
259
+ ,
260
+ @thunk project_x (if dims <= ndims (x)
261
+ vald = Val (Int (dims))
262
+ ∇cumprod_dim (vald, x, dy, y)
263
+ else
264
+ dy
265
+ end )
264
266
)
265
267
return (NO_FIELDS, dx_thunk)
266
268
end
0 commit comments