Skip to content

Commit d2d708e

Browse files
committed
update for 1.0
1 parent e110d74 commit d2d708e

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -225,19 +225,20 @@ end
225225

226226
function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1)
227227
y = cumprod(x; dims=dims) # does nothing unless dims == 1
228+
project_x = ProjectTo(x)
228229
function cumprod_pullback_1(dy)
229230
dx_thunk = InplaceableThunk(
230-
@thunk if dims == 1
231-
∇cumprod(x, dy, y)
232-
else
233-
dy
234-
end
235-
,
236231
dx -> if dims == 1
237232
∇cumprod!(dx, x, dy, y)
238233
else
239234
dx .+= dy
240235
end
236+
,
237+
@thunk project_x(if dims == 1
238+
∇cumprod(x, dy, y)
239+
else
240+
dy
241+
end)
241242
)
242243
return (NO_FIELDS, dx_thunk)
243244
end
@@ -246,21 +247,22 @@ end
246247

247248
function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer)
248249
y = cumprod(x; dims=dims)
250+
project_x = ProjectTo(x)
249251
function cumprod_pullback_2(dy)
250252
dx_thunk = InplaceableThunk(
251-
@thunk if dims <= ndims(x)
252-
vald = Val(Int(dims))
253-
∇cumprod_dim(vald, x, dy, y)
254-
else
255-
dy
256-
end
257-
,
258253
dx -> if dims <= ndims(x)
259254
vald = Val(Int(dims))
260255
∇cumprod_dim!(dx, vald, x, dy, y)
261256
else
262257
dx .+= dy
263258
end
259+
,
260+
@thunk project_x(if dims <= ndims(x)
261+
vald = Val(Int(dims))
262+
∇cumprod_dim(vald, x, dy, y)
263+
else
264+
dy
265+
end)
264266
)
265267
return (NO_FIELDS, dx_thunk)
266268
end

0 commit comments

Comments
 (0)