Skip to content

Commit c5db487

Browse files
committed
tidy & fix tests
1 parent 4b4c347 commit c5db487

File tree

2 files changed

+10
-23
lines changed

2 files changed

+10
-23
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -247,16 +247,17 @@ end
247247
function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims)
248248
y = cumprod(x; dims=dims)
249249
@assert dims isa Integer
250-
vald = Val(Int(dims)) # else ∇cumprod_dim! will be type unstable
251250
function cumprod_pullback_2(dy)
252251
dx_thunk = InplaceableThunk(
253252
@thunk if dims <= ndims(x)
253+
vald = Val(Int(dims))
254254
∇cumprod_dim(vald, x, dy, y)
255255
else
256256
dy
257257
end
258258
,
259259
dx -> if dims <= ndims(x)
260+
vald = Val(Int(dims))
260261
∇cumprod_dim!(dx, vald, x, dy, y)
261262
else
262263
dx .+= dy
@@ -282,19 +283,6 @@ function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y)
282283
return dx
283284
end
284285

285-
#=
286-
287-
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1]
288-
86.333 μs (2007 allocations: 67.54 KiB)
289-
290-
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] # with 1 hard-coded
291-
5.417 μs (6 allocations: 15.95 KiB)
292-
293-
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] # with Val(dim)
294-
5.423 μs (6 allocations: 15.95 KiB)
295-
296-
=#
297-
298286
function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
299287
T = promote_type(eltype(x), eltype(dy))
300288
dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices

test/rulesets/Base/mapreduce.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
dy = sum(x; dims=dims)
3232
ddy = rrule(ChainRules._unsum, x, dy, dims)[2](x)[3]
3333
@test size(ddy) == size(dy)
34-
end
3534
end
3635

3736
@testset "sum abs2" begin
@@ -155,37 +154,37 @@
155154
end # prod
156155
end
157156

158-
159-
157+
@testset "Accumulations" begin
160158
@testset "cumprod" begin
161-
v = randn(9)
159+
v = round.(10 .* randn(9), sigdigits=3)
162160
test_rrule(cumprod, v)
163161
v[3] = 0
164162
test_rrule(cumprod, v)
165163
v[6] = 0
166164
test_rrule(cumprod, v)
167165

168166
@testset "higher dimensions, dims=$dims" for dims in (1,2,3)
169-
m = rand(4,5)
167+
m = round.(10 .* randn(4,5), sigdigits=3)
170168
test_rrule(cumprod, m; fkwargs=(;dims=dims))
171169
m[2,2] = 0
172170
m[2,4] = 0
173171
test_rrule(cumprod, m; fkwargs=(;dims=dims))
174172

175-
t = randn(3,3,3)
176-
test_rrule(cumprod, x; fkwargs=(;dims=dims))
173+
t = round.(10 .* randn(3,3,3), sigdigits=3)
174+
test_rrule(cumprod, t; fkwargs=(;dims=dims))
177175
end
178176

179177
@testset "types" begin
180-
back = unthunk(rrule(cumprod, [1, 2, 3])[2])
178+
back = unthunk(rrule(cumprod, [1, 2, 3])[2]) # allow integer input
181179
@test unthunk(back(fill(0.5, 3))[2]) == [9/2, 2, 1]
182180

183181
back = unthunk(rrule(cumprod, PermutedDimsArray([1 2; 3 4], (2,1)); dims=1)[2])
184182
@test unthunk(back(ones(Float32, 2,2))[2]) == [3 5; 1 3]
185183

186-
@test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1)
184+
@test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) # forward pass fails
187185

188186
back = unthunk(rrule(cumprod, Diagonal([1, 2]); dims=1)[2])
189187
@test unthunk(back(fill(0.5, 2, 2))[2]) [1/2 3/2; 1/2 0]
190188
end
191189
end
190+
end

0 commit comments

Comments
 (0)