Skip to content

Commit 8dc43d1

Browse files
authored
Merge pull request #165 from JuliaDiff/ox/covectors
Make rand_tangent on adjoint an transpose return natural
2 parents ebff52d + f71f77b commit 8dc43d1

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

src/rand_tangent.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent()
1313

1414
rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T)
1515

16-
# TODO: right now Julia don't allow `randn(rng, BigFloat)`
16+
# TODO: right now Julia don't allow `randn(rng, BigFloat)`
1717
# see: https://github.com/JuliaLang/julia/issues/17629
1818
rand_tangent(rng::AbstractRNG, ::BigFloat) = big(randn(rng))
1919

2020
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
21+
rand_tangent(rng::AbstractRNG, x::Adjoint) = adjoint(rand_tangent(rng, parent(x)))
22+
rand_tangent(rng::AbstractRNG, x::Transpose) = transpose(rand_tangent(rng, parent(x)))
2123

2224
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
2325
return Tangent{T}(rand_tangent.(Ref(rng), x)...)

test/rand_tangent.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ using FiniteDifferences: rand_tangent
2424
(randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
2525
([randn(5, 4), 4.0], Vector{Any}),
2626

27+
# Wrapper Arrays
28+
(randn(5, 4)', Adjoint{Float64, Matrix{Float64}}),
29+
(transpose(randn(5, 4)), Transpose{Float64, Matrix{Float64}}),
30+
31+
2732
# Tuples.
2833
((4.0, ), Tangent{Tuple{Float64}}),
2934
((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}),
@@ -66,20 +71,19 @@ using FiniteDifferences: rand_tangent
6671
Hermitian(randn(ComplexF64, 1, 1)),
6772
Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}},
6873
),
69-
(
70-
Adjoint(randn(ComplexF64, 3, 3)),
71-
Tangent{Adjoint{ComplexF64, Matrix{ComplexF64}}},
72-
),
73-
(
74-
Transpose(randn(3)),
75-
Tangent{Transpose{Float64, Vector{Float64}}},
76-
),
7774
]
7875
@test rand_tangent(rng, x) isa T_tangent
7976
@test rand_tangent(x) isa T_tangent
80-
@test x + rand_tangent(rng, x) isa typeof(x)
8177
end
8278

83-
# Ensure struct fallback errors for non-struct types.
84-
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
79+
@testset "erroring cases" begin
80+
# Ensure struct fallback errors for non-struct types.
81+
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
82+
end
83+
84+
@testset "compsition of addition" begin
85+
x = Foo(1.5, 2, Foo(1.1, 3, [1.7, 1.4, 0.9]))
86+
@test x + rand_tangent(x) isa typeof(x)
87+
@test x + (rand_tangent(x) + rand_tangent(x)) isa typeof(x)
88+
end
8589
end

0 commit comments

Comments
 (0)