Skip to content

Commit b197e62

Browse files
authored
Rules for filter (#570)
* rules for filter * allow tuples * allow tuples for reverse too * fix tests, or mark skipped * one unrelated line nearby needs cleaning up * rule for getindex(Tuple) * remove rules for Tuple * comments, versions * simplify, version * don't use random numbers
1 parent 15b1452 commit b197e62

File tree

5 files changed

+78
-31
lines changed

5 files changed

+78
-31
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.19"
3+
version = "1.20"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
1111
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1212

1313
[compat]
14-
ChainRulesCore = "1.11"
14+
ChainRulesCore = "1.11.5"
1515
ChainRulesTestUtils = "1"
1616
Compat = "3.35"
1717
FiniteDifferences = "0.12.20"

src/rulesets/Base/array.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,11 @@ end
321321
# 1-dim case allows start/stop, N-dim case takes dims keyword
322322
# whose defaults changed in Julia 1.6... just pass them all through:
323323

324-
function frule((_, xdot), ::typeof(reverse), x::AbstractArray, args...; kw...)
324+
function frule((_, xdot), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
325325
return reverse(x, args...; kw...), reverse(xdot, args...; kw...)
326326
end
327327

328-
function rrule(::typeof(reverse), x::AbstractArray, args...; kw...)
328+
function rrule(::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
329329
nots = map(Returns(NoTangent()), args)
330330
function reverse_pullback(dy)
331331
dx = @thunk reverse(unthunk(dy), args...; kw...)
@@ -360,12 +360,31 @@ function frule((_, xdot), ::typeof(fill), x::Any, dims...)
360360
end
361361

362362
function rrule(::typeof(fill), x::Any, dims...)
363-
project = x isa Union{Number, AbstractArray{<:Number}} ? ProjectTo(x) : identity
363+
project = ProjectTo(x)
364364
nots = map(Returns(NoTangent()), dims)
365365
fill_pullback(Ȳ) = (NoTangent(), project(sum(Ȳ)), nots...)
366366
return fill(x, dims...), fill_pullback
367367
end
368368

369+
#####
370+
##### `filter`
371+
#####
372+
373+
function frule((_, _, xdot), ::typeof(filter), f, x::AbstractArray)
374+
inds = findall(f, x)
375+
return x[inds], xdot[inds]
376+
end
377+
378+
function rrule(::typeof(filter), f, x::AbstractArray)
379+
inds = findall(f, x)
380+
y, back = rrule(getindex, x, inds)
381+
function filter_pullback(dy)
382+
_, dx, _ = back(dy)
383+
return (NoTangent(), NoTangent(), dx)
384+
end
385+
return y, filter_pullback
386+
end
387+
369388
#####
370389
##### `findmax`, `maximum`, etc.
371390
#####

src/rulesets/Base/indexing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...)
2626

2727
return y, getindex_pullback
2828
end
29+

test/rulesets/Base/array.jl

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -164,32 +164,38 @@ end
164164
end
165165

166166
@testset "reverse" begin
167-
# Forward
168-
test_frule(reverse, rand(5))
169-
test_frule(reverse, rand(5), 2, 4)
170-
test_frule(reverse, rand(5), fkwargs=(dims=1,))
171-
172-
test_frule(reverse, rand(3,4), fkwargs=(dims=2,))
173-
if VERSION >= v"1.6"
174-
test_frule(reverse, rand(3,4))
175-
test_frule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),))
167+
@testset "Tuple" begin
168+
test_frule(reverse, Tuple(rand(10)))
169+
@test_skip test_rrule(reverse, Tuple(rand(10))) # Ambiguity in isapprox, https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/229
176170
end
177-
178-
# Reverse
179-
test_rrule(reverse, rand(5))
180-
test_rrule(reverse, rand(5), 2, 4)
181-
test_rrule(reverse, rand(5), fkwargs=(dims=1,))
182-
183-
test_rrule(reverse, rand(3,4), fkwargs=(dims=2,))
184-
if VERSION >= v"1.6"
185-
test_rrule(reverse, rand(3,4))
186-
test_rrule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),))
187-
188-
# Structured
189-
y, pb = rrule(reverse, Diagonal([1,2,3]))
190-
# We only preserve structure in this case if given structured tangent (no ProjectTo)
191-
@test unthunk(pb(Diagonal([1.1, 2.1, 3.1]))[2]) isa Diagonal
192-
@test unthunk(pb(rand(3, 3))[2]) isa AbstractArray
171+
@testset "Array" begin
172+
# Forward
173+
test_frule(reverse, rand(5))
174+
test_frule(reverse, rand(5), 2, 4)
175+
test_frule(reverse, rand(5), fkwargs=(dims=1,))
176+
177+
test_frule(reverse, rand(3,4), fkwargs=(dims=2,))
178+
if VERSION >= v"1.6"
179+
test_frule(reverse, rand(3,4))
180+
test_frule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),))
181+
end
182+
183+
# Reverse
184+
test_rrule(reverse, rand(5))
185+
test_rrule(reverse, rand(5), 2, 4)
186+
test_rrule(reverse, rand(5), fkwargs=(dims=1,))
187+
188+
test_rrule(reverse, rand(3,4), fkwargs=(dims=2,))
189+
if VERSION >= v"1.6"
190+
test_rrule(reverse, rand(3,4))
191+
test_rrule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),))
192+
193+
# Structured
194+
y, pb = rrule(reverse, Diagonal([1,2,3]))
195+
# We only preserve structure in this case if given structured tangent (no ProjectTo)
196+
@test unthunk(pb(Diagonal([1.1, 2.1, 3.1]))[2]) isa Diagonal
197+
@test unthunk(pb(rand(3, 3))[2]) isa AbstractArray
198+
end
193199
end
194200
end
195201

@@ -215,6 +221,27 @@ end
215221
test_rrule(fill, 3.3, (3, 3, 3))
216222
end
217223

224+
@testset "filter" begin
225+
@testset "Array" begin
226+
# Random numbers will confuse finite differencing here, as it may perturb across the boundary.
227+
x5 = [0.0, 1.0, 0.2, 0.9, 0.7]
228+
x34 = Float64[-113 124 -37 12
229+
96 -89 103 119
230+
91 -21 -110 10]
231+
232+
# Forward
233+
test_frule(filter, >(0.5) NoTangent(), x5)
234+
test_frule(filter, <(0), x34)
235+
test_frule(filter, >(100), x5)
236+
237+
# Reverse
238+
test_rrule(filter, >(0.5) NoTangent(), x5) # Without ⊢, MethodError: zero(::Base.Fix2{typeof(>), Float64}) -- https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/231
239+
test_rrule(filter, <(0), x34)
240+
test_rrule(filter, >(100), x5) # fixed in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/534
241+
@test unthunk(rrule(filter, >(100), x5)[2](Int[])[3]) == zero(x5)
242+
end
243+
end
244+
218245
@testset "findmin & findmax" begin
219246
# Forward
220247
test_frule(findmin, rand(10))

test/rulesets/Base/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "getindex" begin
2-
@testset "getindex(::Matrix{<:Number},...)" begin
2+
@testset "getindex(::Matrix{<:Number}, ...)" begin
33
x = [1.0 2.0 3.0; 10.0 20.0 30.0]
44

55
@testset "single element" begin

0 commit comments

Comments
 (0)