Skip to content

Commit c1102fb

Browse files
authored
Merge pull request #192 from JuliaDiff/ox/projrand
use ProjectTo for rand_tangent, and also catch some more places we shoud NoTangent
2 parents 317b81e + 9eeb654 commit c1102fb

File tree

4 files changed

+41
-22
lines changed

4 files changed

+41
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.10.12, 1"
14+
ChainRulesCore = "1"
1515
Compat = "3"
1616
FiniteDifferences = "0.12.12"
1717
julia = "1"

src/rand_tangent.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,21 @@ end
2929
# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0.
3030
rand_tangent(rng::AbstractRNG, ::BigFloat) = round(big(9 * randn(rng)), sigdigits=5, base=2)
3131

32-
rand_tangent(rng::AbstractRNG, x::StridedArray{T, 0}) where {T} = fill(rand_tangent(x[1]))
33-
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
34-
rand_tangent(rng::AbstractRNG, x::Adjoint) = adjoint(rand_tangent(rng, parent(x)))
35-
rand_tangent(rng::AbstractRNG, x::Transpose) = transpose(rand_tangent(rng, parent(x)))
3632

37-
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
38-
return Tangent{T}(rand_tangent.(Ref(rng), x)...)
39-
end
33+
rand_tangent(rng::AbstractRNG, x::Array{<:Any, 0}) = _compress_notangent(fill(rand_tangent(rng, x[])))
34+
rand_tangent(rng::AbstractRNG, x::Array) = _compress_notangent(rand_tangent.(Ref(rng), x))
4035

41-
function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple}
42-
return Tangent{T}(; map(x -> rand_tangent(rng, x), xs)...)
36+
# All other AbstractArray's can be handled using the ProjectTo mechanics.
37+
# and follow the same requirements
38+
function rand_tangent(rng::AbstractRNG, x::AbstractArray)
39+
return _compress_notangent(ProjectTo(x)(rand_tangent(rng, collect(x))))
4340
end
4441

42+
# TODO: arguably ProjectTo should handle this for us for AbstactArrays
43+
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/410
44+
_compress_notangent(::AbstractArray{NoTangent}) = NoTangent()
45+
_compress_notangent(x) = x
46+
4547
function rand_tangent(rng::AbstractRNG, x::T) where {T}
4648
if !isstructtype(T)
4749
throw(ArgumentError("Non-struct types are not supported by this fallback."))
@@ -54,8 +56,12 @@ function rand_tangent(rng::AbstractRNG, x::T) where {T}
5456
if all(tangent isa NoTangent for tangent in tangents)
5557
# if none of my fields can be perturbed then I can't be perturbed
5658
return NoTangent()
59+
end
60+
61+
if T <: Tuple
62+
return Tangent{T}(tangents...)
5763
else
58-
Tangent{T}(; NamedTuple{field_names}(tangents)...)
64+
return Tangent{T}(; NamedTuple{field_names}(tangents)...)
5965
end
6066
end
6167

test/rand_tangent.jl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,44 +39,54 @@ struct Bar
3939
(randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
4040
([randn(5, 4), 4.0], Vector{Any}),
4141

42-
# Wrapper Arrays
43-
(randn(5, 4)', Adjoint{Float64, Matrix{Float64}}),
44-
(transpose(randn(5, 4)), Transpose{Float64, Matrix{Float64}}),
45-
42+
# Co-Arrays
43+
(randn(5)', Adjoint{Float64, Vector{Float64}}), # row-vector: special
44+
(randn(5, 4)', Matrix{Float64}), # matrix: generic dense
45+
46+
(transpose(randn(5)), Transpose{Float64, Vector{Float64}}), # row-vector: special
47+
(transpose(randn(5, 4)), Matrix{Float64}), # matrix: generic dense
48+
49+
# AbstactArrays of non-perturbable types
50+
(1:10, NoTangent),
51+
(1:2:10, NoTangent),
52+
([false, true], NoTangent),
4653

4754
# Tuples.
4855
((4.0, ), Tangent{Tuple{Float64}}),
4956
((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}),
57+
((false, true), NoTangent),
58+
(Tuple{}(), NoTangent),
5059

5160
# NamedTuples.
5261
((a=4.0, ), Tangent{NamedTuple{(:a,), Tuple{Float64}}}),
5362
((a=5.0, b=1), Tangent{NamedTuple{(:a, :b), Tuple{Float64, Int}}}),
63+
((a=false, b=true), NoTangent),
64+
((;), NoTangent),
5465

5566
# structs.
5667
(Bar(5.0, 4, rand(rng, 3)), Tangent{Bar}),
5768
(Bar(4.0, 3, Bar(5.0, 2, 4)), Tangent{Bar}),
5869
(sin, NoTangent),
5970
# all fields NoTangent implies NoTangent
6071
(Pair(:a, "b"), NoTangent),
61-
(1:10, NoTangent),
62-
(1:2:10, NoTangent),
72+
(CartesianIndex(2, 3), NoTangent),
6373

64-
# LinearAlgebra types (also just structs).
74+
# LinearAlgebra types
6575
(
6676
UpperTriangular(randn(3, 3)),
67-
Tangent{UpperTriangular{Float64, Matrix{Float64}}},
77+
UpperTriangular{Float64, Matrix{Float64}},
6878
),
6979
(
7080
Diagonal(randn(2)),
71-
Tangent{Diagonal{Float64, Vector{Float64}}},
81+
Diagonal{Float64, Vector{Float64}},
7282
),
7383
(
7484
Symmetric(randn(2, 2)),
75-
Tangent{Symmetric{Float64, Matrix{Float64}}},
85+
Symmetric{Float64, Matrix{Float64}},
7686
),
7787
(
7888
Hermitian(randn(ComplexF64, 1, 1)),
79-
Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}},
89+
Hermitian{ComplexF64, Matrix{ComplexF64}},
8090
),
8191
]
8292
@test rand_tangent(rng, x) isa T_tangent
@@ -96,6 +106,7 @@ struct Bar
96106

97107
# Julia 1.6 changed to using Ryu printing algorithm and seems better at printing short
98108
VERSION > v"1.6" && @testset "niceness of printing" begin
109+
rng = MersenneTwister()
99110
for i in 1:50
100111
@test length(string(rand_tangent(1.0))) <= 6
101112
@test length(string(rand_tangent(1.0 + 1.0im))) <= 12
@@ -104,3 +115,4 @@ struct Bar
104115
end
105116
end
106117
end
118+

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ChainRulesCore
22
using ChainRulesTestUtils
3+
using ChainRulesTestUtils: rand_tangent
34
using FiniteDifferences
45
using LinearAlgebra
56
using Random

0 commit comments

Comments
 (0)