Skip to content

Commit 9ab71f7

Browse files
committed
Fix Diagonal tests
1 parent 7465575 commit 9ab71f7

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

test/layers/basic.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,18 @@ import Flux: activations
8989

9090
@testset "Diagonal" begin
9191
@test length(Flux.Diagonal(10)(randn(10))) == 10
92-
@test length(Flux.Diagonal(10)(1)) == 10
9392
@test length(Flux.Diagonal(10)(randn(1))) == 10
9493
@test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
9594
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
9695

9796
@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
98-
@test Flux.Diagonal(2)([1,2]) == [1,2]
97+
@test Flux.Diagonal(2)([1, 2]) == [1, 2]
9998
@test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
10099

101-
@test Flux.Diagonal(2)(rand(2,3,4)) |> size == (2, 3, 4)
102-
@test Flux.Diagonal(2,3)(rand(2,3,4)) |> size == (2, 3, 4)
103-
@test Flux.Diagonal(2, 3, 4; bias = false)(rand(2,3,4)) |> size == (2, 3, 4)
104-
@test Flux.Diagonal(2, 3; bias = false)(rand(2,1,4)) |> size == (2, 3, 4)
100+
@test Flux.Diagonal(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
101+
@test Flux.Diagonal(2, 3;)(rand(2, 3, 4)) |> size == (2, 3, 4)
102+
@test Flux.Diagonal(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
103+
@test Flux.Diagonal(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
105104
end
106105

107106
@testset "Maxout" begin

0 commit comments

Comments
 (0)