Skip to content

Commit 63bbd48

Browse files
authored
Fix tangent type of arrays of (named) tuples from FD (#224)
* Fix tangent type of arrays of (named) tuples from FD * Use `ProjectTo` instead of `_maybe_fix_to_composite` * Update docs/Manifest.toml * Add test
1 parent 394f539 commit 63bbd48

File tree

4 files changed

+57
-27
lines changed

4 files changed

+57
-27
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "1.3.0"
3+
version = "1.3.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "1.2"
14+
ChainRulesCore = "1.11.2"
1515
Compat = "3"
1616
FiniteDifferences = "0.12.12"
1717
julia = "1"

docs/Manifest.toml

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1616

1717
[[ChainRulesCore]]
1818
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
19-
git-tree-sha1 = "bdc0937269321858ab2a4f288486cb258b9a0af7"
19+
git-tree-sha1 = "f885e7e7c124f8c92650d61b9477b9ac2ee607dd"
2020
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
21-
version = "1.3.0"
21+
version = "1.11.1"
2222

2323
[[ChainRulesTestUtils]]
2424
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
2525
path = ".."
2626
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
27-
version = "1.2.1"
27+
version = "1.2.4"
2828

2929
[[Compat]]
3030
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
31-
git-tree-sha1 = "727e463cfebd0c7b999bbf3e9e7e16f254b94193"
31+
git-tree-sha1 = "dce3e3fea680869eaa0b774b2e8343e9ff442313"
3232
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
33-
version = "3.34.0"
33+
version = "3.40.0"
3434

3535
[[Dates]]
3636
deps = ["Printf"]
@@ -46,25 +46,25 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
4646

4747
[[DocStringExtensions]]
4848
deps = ["LibGit2"]
49-
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
49+
git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b"
5050
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
51-
version = "0.8.5"
51+
version = "0.8.6"
5252

5353
[[Documenter]]
5454
deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
55-
git-tree-sha1 = "350dced36c11f794c6c4da5dc6493ec894e50c16"
55+
git-tree-sha1 = "f425293f7e0acaf9144de6d731772de156676233"
5656
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
57-
version = "0.27.5"
57+
version = "0.27.10"
5858

5959
[[Downloads]]
6060
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
6161
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
6262

6363
[[FiniteDifferences]]
6464
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
65-
git-tree-sha1 = "9a586f04a21e6945f4cbee0d0fb6aebd7b86aa8f"
65+
git-tree-sha1 = "c56a261e1a5472f20cbd7aa218840fd203243319"
6666
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
67-
version = "0.12.18"
67+
version = "0.12.19"
6868

6969
[[IOCapture]]
7070
deps = ["Logging", "Random"]
@@ -127,9 +127,9 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
127127

128128
[[Parsers]]
129129
deps = ["Dates"]
130-
git-tree-sha1 = "438d35d2d95ae2c5e8780b330592b6de8494e779"
130+
git-tree-sha1 = "ae4bbcadb2906ccc085cf52ac286dc1377dceccc"
131131
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
132-
version = "2.0.3"
132+
version = "2.1.2"
133133

134134
[[Pkg]]
135135
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
@@ -172,9 +172,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
172172

173173
[[StaticArrays]]
174174
deps = ["LinearAlgebra", "Random", "Statistics"]
175-
git-tree-sha1 = "3240808c6d463ac46f1c1cd7638375cd22abbccb"
175+
git-tree-sha1 = "3c76dde64d03699e074ac02eb2e8ba8254d428da"
176176
uuid = "90137ffa-7385-5640-81b9-e52037218182"
177-
version = "1.2.12"
177+
version = "1.2.13"
178178

179179
[[Statistics]]
180180
deps = ["LinearAlgebra", "SparseArrays"]

src/finite_difference_calls.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function _make_jvp_call(fdm, f, y, xs, ẋs, ignores)
2121
ignores = collect(ignores)
2222
all(ignores) && return ntuple(_ -> NoTangent(), length(xs))
2323
sigargs = zip(xs[.!ignores], ẋs[.!ignores])
24-
return _maybe_fix_to_composite(y, jvp(fdm, f2, sigargs...))
24+
return ProjectTo(y)(jvp(fdm, f2, sigargs...))
2525
end
2626

2727
"""
@@ -52,7 +52,7 @@ function _make_j′vp_call(fdm, f, ȳ, xs, ignores)
5252
@assert length(fd) == length(arginds)
5353

5454
for (dx, ind) in zip(fd, arginds)
55-
args[ind] = _maybe_fix_to_composite(xs[ind], dx)
55+
args[ind] = ProjectTo(xs[ind])(dx)
5656
end
5757
return (args...,)
5858
end
@@ -87,10 +87,3 @@ function _wrap_function(f, xs, ignores)
8787
end
8888
return fnew
8989
end
90-
91-
# TODO: remove after https://github.com/JuliaDiff/FiniteDifferences.jl/issues/97
92-
# For functions which return a tuple, FD returns a tuple to represent the differential. Tuple
93-
# is not a natural differential, because it doesn't overload +, so make it a Tangent.
94-
_maybe_fix_to_composite(::P, x::Tuple) where {P} = Tangent{P}(x...)
95-
_maybe_fix_to_composite(::P, x::NamedTuple) where {P} = Tangent{P}(; x...)
96-
_maybe_fix_to_composite(::Any, x) = x

test/testers.jl

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ end
5757
abstract type MySpecialTrait end
5858
struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
5959

60+
# Type-stable derivative for test below
61+
struct FVecOfTuplesPullback{T} end
62+
function (f::FVecOfTuplesPullback{T})(Δ) where {T}
63+
ΔΩ_first, ΔΩ_last = unthunk(Δ)
64+
Δx = map(z -> Tangent{T}(z, ΔΩ_last), ΔΩ_first)
65+
return NoTangent(), Δx
66+
end
67+
6068
@testset "testers.jl" begin
6169
@testset "test_scalar" begin
6270
@testset "Ensure correct rules succeed" begin
@@ -608,7 +616,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
608616
test_rrule(
609617
does_not_accept_thunk_id, [1.0, 2.0]; check_thunked_output_tangent=false
610618
)
611-
@test errors(r"MethodError.*Thunk") do
619+
@test errors(r"MethodError.*Thunk") do
612620
test_rrule(does_not_accept_thunk_id, [1.0, 2.0])
613621
end
614622
end
@@ -736,4 +744,33 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
736744
test_rrule(my_id, 2.0; check_inferred=false)
737745
test_rrule(my_id, 2.0; check_thunked_output_tangent=false)
738746
end
747+
748+
# https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/224
749+
@testset "vectors of tuples" begin
750+
function f_vec_of_tuples(x::AbstractVector{<:Tuple{<:Any,<:Any}})
751+
return map(first, x), sum(last, x)
752+
end
753+
function ChainRulesCore.frule(
754+
(_, Δx),
755+
::typeof(f_vec_of_tuples),
756+
x::AbstractVector{<:Tuple{<:Any,<:Any}},
757+
)
758+
Ω = f_vec_of_tuples(x)
759+
Ω̄ = Tangent{typeof(Ω)}(f_vec_of_tuples(map(ChainRulesCore.backing, Δx))...)
760+
return Ω, Ω̄
761+
end
762+
function ChainRulesCore.rrule(
763+
::typeof(f_vec_of_tuples),
764+
x::AbstractVector{<:Tuple{<:Any,<:Any}},
765+
)
766+
Ω = f_vec_of_tuples(x)
767+
# We use a functor here to fix type inference
768+
f_vec_of_tuples_pullback = FVecOfTuplesPullback{eltype(x)}()
769+
return Ω, f_vec_of_tuples_pullback
770+
end
771+
772+
x_tuples = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
773+
test_frule(f_vec_of_tuples, x_tuples)
774+
test_rrule(f_vec_of_tuples, x_tuples)
775+
end
739776
end

0 commit comments

Comments
 (0)