Skip to content

Commit 86d8bfe

Browse files
authored
Merge pull request #238 from dfdx/dfdx/fix-rand-tangent
Add rand_tangent() for Broadcasted
2 parents b8f4217 + 3d9b3a4 commit 86d8bfe

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "1.5.1"
3+
version = "1.5.2"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rand_tangent.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ rand_tangent(rng::AbstractRNG, ::BigFloat) = round(big(9 * randn(rng)), digits=5
2424

2525
rand_tangent(rng::AbstractRNG, x::AbstractArray) = ProjectTo(x)(rand_tangent.(Ref(rng), x))
2626

27+
rand_tangent(rng::AbstractRNG, x::Broadcast.Broadcasted) = rand_tangent(Broadcast.materialize(x))
28+
2729
function rand_tangent(rng::AbstractRNG, x::T) where {T}
2830
if !isstructtype(T)
2931
throw(ArgumentError("Non-struct types are not supported by this fallback."))

test/rand_tangent.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct Bar
1717
(4, NoTangent),
1818
(FiniteDifferences, NoTangent), # Module object
1919
# Types (not instances of type)
20-
(Bar, NoTangent),
20+
(Bar, NoTangent),
2121
(Union{Int, Bar}, NoTangent),
2222
(Union{Int, Bar}, NoTangent),
2323
(Vector, NoTangent),
@@ -42,15 +42,18 @@ struct Bar
4242
# Co-Arrays
4343
(randn(5)', Adjoint{Float64, Vector{Float64}}), # row-vector: special
4444
(randn(5, 4)', Matrix{Float64}), # matrix: generic dense
45-
45+
4646
(transpose(randn(5)), Transpose{Float64, Vector{Float64}}), # row-vector: special
4747
(transpose(randn(5, 4)), Matrix{Float64}), # matrix: generic dense
48-
48+
4949
# AbstactArrays of non-perturbable types
5050
(1:10, NoTangent),
5151
(1:2:10, NoTangent),
5252
([false, true], NoTangent),
5353

54+
# Broadcasted.
55+
(Broadcast.broadcasted(sin, randn(5, 4)), Matrix{Float64}),
56+
5457
# Tuples.
5558
((4.0, ), Tangent{Tuple{Float64}}),
5659
((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}),

0 commit comments

Comments
 (0)