Skip to content

Commit 0619f15

Browse files
author
Michael Abbott
committed
outputsize test, and simplify printing
1 parent 37c838d commit 0619f15

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/utils.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,7 @@ julia> rand(Float32, 10, 10) |> m |> size
289289
"""
290290
unsqueeze(dim::Integer) = Base.Fix2(unsqueeze, dim)
291291

292-
Base.show(io::IO, u::Base.Fix2{typeof(unsqueeze)}) = print(io, "unsqueeze(", u.x, ")")
293-
Base.show(io::IO, ::MIME"text/plain", u::Base.Fix2{typeof(unsqueeze)}) = show(io, u) # at top level
294-
Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) = show(io, u) # within Chain etc.
292+
Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) = print(io, "unsqueeze(", u.x, ")")
295293

296294
"""
297295
stack(xs, dim)

test/outputsize.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
m = flatten
2323
@test outputsize(m, (5, 5, 3, 10)) == (75, 10)
2424

25+
m = Flux.unsqueeze(3)
26+
@test outputsize(m, (5, 7, 13)) == (5, 7, 1, 13)
27+
2528
m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10))
2629
@test outputsize(m, (10, 10, 3, 50)) == (10, 50)
2730
@test outputsize(m, (10, 10, 3, 2)) == (10, 2)

0 commit comments

Comments
 (0)