Skip to content

Commit d34a524

Browse files
committed
Fix test_approx() when one of the arguments is a broadcasted multidimensional array.
`collect()` on broadcasted arrays doesn't preserve its shape, but instead creates a flat array. This breaks tests e.g. for most activation functions in NNlib: ``` test_rrule(Broadcast.broadcasted, NNlib.σ, rand(3, 4)) ```
1 parent e30fcb4 commit d34a524

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/check_result.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ function test_approx(actual::A, expected::E, msg=""; kwargs...) where {A,E}
142142
if _can_pass_early(actual, expected)
143143
@test true
144144
else
145-
c_actual = collect(actual)
146-
c_expected = collect(expected)
145+
c_actual = collect(Broadcast.materialize(actual))
146+
c_expected = collect(Broadcast.materialize(expected))
147147
if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow
148148
throw(MethodError, test_approx, (actual, expected))
149149
end

test/check_result.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ end
3636

3737
test_approx([1.0, 2.0], [1.0, 2.0])
3838
test_approx([[1.0], [2.0]], [[1.0], [2.0]])
39+
test_approx(Broadcast.broadcasted(identity, [1.0 2.0; 3.0 4.0]), [1.0 2.0; 3.0 4.0])
3940

4041
test_approx(@thunk(10 * 0.1 * [[1.0], [2.0]]), [[1.0], [2.0]])
4142

0 commit comments

Comments
 (0)