Skip to content

Commit b66a0f8

Browse files
committed
use ProjectTo for rand_tangent, and also catch some more places we shoud NoTangent
1 parent 8f13880 commit b66a0f8

File tree

3 files changed

+39
-21
lines changed

3 files changed

+39
-21
lines changed

src/rand_tangent.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,21 @@ end
2929
# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0.
3030
rand_tangent(rng::AbstractRNG, ::BigFloat) = round(big(9 * randn(rng)), sigdigits=5, base=2)
3131

32-
rand_tangent(rng::AbstractRNG, x::StridedArray{T, 0}) where {T} = fill(rand_tangent(x[1]))
33-
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
34-
rand_tangent(rng::AbstractRNG, x::Adjoint) = adjoint(rand_tangent(rng, parent(x)))
35-
rand_tangent(rng::AbstractRNG, x::Transpose) = transpose(rand_tangent(rng, parent(x)))
3632

37-
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
38-
return Tangent{T}(rand_tangent.(Ref(rng), x)...)
39-
end
33+
rand_tangent(rng::AbstractRNG, x::Array{<:Any, 0}) = _compress_notangent(fill(rand_tangent(x[])))
34+
rand_tangent(rng::AbstractRNG, x::Array) = _compress_notangent(rand_tangent.(Ref(rng), x))
4035

41-
function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple}
42-
return Tangent{T}(; map(x -> rand_tangent(rng, x), xs)...)
36+
# All other AbstractArray's can be handled using the ProjectTo mechanics.
37+
# and follow the same requirements
38+
function rand_tangent(rng::AbstractRNG, x::AbstractArray)
39+
return _compress_notangent(ProjectTo(x)(rand_tangent(collect(x))))
4340
end
4441

42+
# TODO: arguably ProjectTo should handle this for us for AbstactArrays
43+
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/410
44+
_compress_notangent(::AbstractArray{NoTangent}) = NoTangent()
45+
_compress_notangent(x) = x
46+
4547
function rand_tangent(rng::AbstractRNG, x::T) where {T}
4648
if !isstructtype(T)
4749
throw(ArgumentError("Non-struct types are not supported by this fallback."))
@@ -54,8 +56,12 @@ function rand_tangent(rng::AbstractRNG, x::T) where {T}
5456
if all(tangent isa NoTangent for tangent in tangents)
5557
# if none of my fields can be perturbed then I can't be perturbed
5658
return NoTangent()
59+
end
60+
61+
if T <: Tuple
62+
return Tangent{T}(tangents...)
5763
else
58-
Tangent{T}(; NamedTuple{field_names}(tangents)...)
64+
return Tangent{T}(; NamedTuple{field_names}(tangents)...)
5965
end
6066
end
6167

test/rand_tangent.jl

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,44 +39,53 @@ struct Bar
3939
(randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
4040
([randn(5, 4), 4.0], Vector{Any}),
4141

42-
# Wrapper Arrays
43-
(randn(5, 4)', Adjoint{Float64, Matrix{Float64}}),
44-
(transpose(randn(5, 4)), Transpose{Float64, Matrix{Float64}}),
45-
42+
# Co-Arrays
43+
(randn(5)', Adjoint{Float64, Vector{Float64}}), # row-vector: special
44+
(randn(5, 4)', Matrix{Float64}), # matrix: generic dense
45+
46+
(transpose(randn(5)), Transpose{Float64, Vector{Float64}}), # row-vector: special
47+
(transpose(randn(5, 4)), Matrix{Float64}), # matrix: generic dense
48+
49+
# AbstactArrays of non-perturbable types
50+
(1:10, NoTangent),
51+
(1:2:10, NoTangent),
52+
([false, true], NoTangent),
4653

4754
# Tuples.
4855
((4.0, ), Tangent{Tuple{Float64}}),
4956
((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}),
57+
((false, true), NoTangent),
58+
(Tuple{}(), NoTangent),
5059

5160
# NamedTuples.
5261
((a=4.0, ), Tangent{NamedTuple{(:a,), Tuple{Float64}}}),
5362
((a=5.0, b=1), Tangent{NamedTuple{(:a, :b), Tuple{Float64, Int}}}),
63+
((a=false, b=true), NoTangent),
64+
((;), NoTangent),
5465

5566
# structs.
5667
(Bar(5.0, 4, rand(rng, 3)), Tangent{Bar}),
5768
(Bar(4.0, 3, Bar(5.0, 2, 4)), Tangent{Bar}),
5869
(sin, NoTangent),
5970
# all fields NoTangent implies NoTangent
6071
(Pair(:a, "b"), NoTangent),
61-
(1:10, NoTangent),
62-
(1:2:10, NoTangent),
6372

64-
# LinearAlgebra types (also just structs).
73+
# LinearAlgebra types
6574
(
6675
UpperTriangular(randn(3, 3)),
67-
Tangent{UpperTriangular{Float64, Matrix{Float64}}},
76+
UpperTriangular{Float64, Matrix{Float64}},
6877
),
6978
(
7079
Diagonal(randn(2)),
71-
Tangent{Diagonal{Float64, Vector{Float64}}},
80+
Diagonal{Float64, Vector{Float64}},
7281
),
7382
(
7483
Symmetric(randn(2, 2)),
75-
Tangent{Symmetric{Float64, Matrix{Float64}}},
84+
Symmetric{Float64, Matrix{Float64}},
7685
),
7786
(
7887
Hermitian(randn(ComplexF64, 1, 1)),
79-
Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}},
88+
Hermitian{ComplexF64, Matrix{ComplexF64}},
8089
),
8190
]
8291
@test rand_tangent(rng, x) isa T_tangent
@@ -96,6 +105,7 @@ struct Bar
96105

97106
# Julia 1.6 changed to using Ryu printing algorithm and seems better at printing short
98107
VERSION > v"1.6" && @testset "niceness of printing" begin
108+
rng = MersenneTwister()
99109
for i in 1:50
100110
@test length(string(rand_tangent(1.0))) <= 6
101111
@test length(string(rand_tangent(1.0 + 1.0im))) <= 12
@@ -104,3 +114,4 @@ struct Bar
104114
end
105115
end
106116
end
117+

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ChainRulesCore
22
using ChainRulesTestUtils
3+
using ChainRulesTestUtils: rand_tangent
34
using FiniteDifferences
45
using LinearAlgebra
56
using Random

0 commit comments

Comments
 (0)