Skip to content

Commit ffd6ae9

Browse files
committed
Add new caching behavior to fcollect
1 parent e402dad commit ffd6ae9

File tree

4 files changed

+20
-9
lines changed

4 files changed

+20
-9
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ julia = "1.6"
1212

1313
[extras]
1414
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
15+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1516
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1617
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

1819
[targets]
19-
test = ["Test", "Documenter", "Zygote"]
20+
test = ["Test", "Documenter", "StaticArrays", "Zygote"]

src/walks.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ end
8888

8989
struct NoKeyword end
9090

91-
usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x)
91+
usecache(::Union{AbstractDict, AbstractSet}, x) =
92+
isleaf(x) ? anymutable(x) : ismutable(x)
9293
usecache(::Nothing, x) = false
9394

9495
@generated function anymutable(x::T) where {T}
@@ -149,9 +150,11 @@ CollectWalk() = CollectWalk(Base.IdSet(), Any[])
149150
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
150151
# for the results, to preserve traversal order (important downstream!).
151152
function (walk::CollectWalk)(recurse, x)
152-
x in walk.cache && return walk.output
153+
if usecache(walk.cache, x) && (x in walk.cache)
154+
return walk.output
155+
end
153156
# to exclude, we wrap this walk in ExcludeWalk
154-
push!(walk.cache, x)
157+
usecache(walk.cache, x) && push!(walk.cache, x)
155158
push!(walk.output, x)
156159
map(recurse, children(x))
157160

test/basics.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ end
8585
@test m5f.x === m5f.y
8686
@test m5f.x !== m5f.z
8787

88-
@testset "usecache" begin
89-
d = IdDict()
90-
88+
@testset "usecache ($d)" for d in [IdDict(), Base.IdSet()]
9189
# Leaf types:
9290
@test usecache(d, [1,2])
9391
@test !usecache(d, 4.0)
@@ -101,9 +99,9 @@ end
10199

102100
@test !usecache(d, (5, [6.0]')) # contains mutable
103101
@test !usecache(d, (x = [1,2,3], y = 4))
104-
102+
105103
usecache(d, OneChild3([1,2], 3, nothing)) # mutable isn't a child, do we care?
106-
104+
107105
# No dictionary:
108106
@test !usecache(nothing, [1,2])
109107
@test !usecache(nothing, 3)
@@ -173,6 +171,14 @@ end
173171
m2 = [1, 2, 3]
174172
m3 = Foo(m1, m2)
175173
@test all(fcollect(m3) .=== [m3, m1, m2])
174+
175+
m1 = [1, 2, 3]
176+
m2 = SVector{length(m1)}(m1)
177+
m2′ = SVector{length(m1)}(m1)
178+
m3 = Foo(m1, m1)
179+
m4 = Foo(m2, m2′)
180+
@test all(fcollect(m3) .=== [m3, m1])
181+
@test all(fcollect(m4) .=== [m4, m2, m2′])
176182
end
177183

178184
###

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Functors, Test
22
using Zygote
33
using LinearAlgebra
4+
using StaticArrays
45

56
@testset "Functors.jl" begin
67

0 commit comments

Comments
 (0)