Skip to content

Commit a988dee

Browse files
committed
fixup
1 parent d2d708e commit a988dee

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ end
226226
function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1)
227227
y = cumprod(x; dims=dims) # does nothing unless dims == 1
228228
project_x = ProjectTo(x)
229-
function cumprod_pullback_1(dy)
229+
function cumprod_pullback_1(dy_raw)
230+
dy = unthunk(dy_raw)
230231
dx_thunk = InplaceableThunk(
231232
dx -> if dims == 1
232233
∇cumprod!(dx, x, dy, y)
@@ -240,15 +241,16 @@ function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1)
240241
dy
241242
end)
242243
)
243-
return (NO_FIELDS, dx_thunk)
244+
return (NoTangent(), dx_thunk)
244245
end
245246
return y, cumprod_pullback_1
246247
end
247248

248249
function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer)
249250
y = cumprod(x; dims=dims)
250251
project_x = ProjectTo(x)
251-
function cumprod_pullback_2(dy)
252+
function cumprod_pullback_2(dy_raw)
253+
dy = unthunk(dy_raw)
252254
dx_thunk = InplaceableThunk(
253255
dx -> if dims <= ndims(x)
254256
vald = Val(Int(dims))
@@ -264,7 +266,7 @@ function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer)
264266
dy
265267
end)
266268
)
267-
return (NO_FIELDS, dx_thunk)
269+
return (NoTangent(), dx_thunk)
268270
end
269271
return y, cumprod_pullback_2
270272
end

test/rulesets/Base/mapreduce.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ end
187187
@test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) # forward pass fails, so can't test gradient
188188

189189
back = rrule(cumprod, Diagonal([1, 2]); dims=1)[2]
190-
@test unthunk(back(fill(0.5, 2, 2))[2]) [1/2 3/2; 1/2 0]
190+
@test unthunk(back(fill(0.5, 2, 2))[2]) [1/2 0; 0 0] # ProjectTo'd to Diagonal now
191191
end
192192
end
193193
end

0 commit comments

Comments
 (0)