Skip to content

Commit 93e1de7

Browse files
authored
Some small printing upgrades (#2344)
* some printing upgrades * print eltype too * move one line to solve order-of-loading issue * better fix * tests, and Fix1
1 parent c9bab66 commit 93e1de7

File tree

3 files changed

+76
-7
lines changed

3 files changed

+76
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Flux"
22
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
3-
version = "0.14.22"
3+
version = "0.14.23"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/layers/show.jl

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ function _macro_big_show(ex)
2020
end
2121

2222
function _big_show(io::IO, obj, indent::Int=0, name=nothing)
23-
pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")")
23+
pre, post = _show_pre_post(obj)
2424
children = _show_children(obj)
2525
if all(_show_leaflike, children)
2626
# This check may not be useful anymore: it tries to infer when to stop the recursion by looking for grandkids,
2727
# but once all layers use @layer, they stop the recursion by defining a method for _big_show.
2828
_layer_show(io, obj, indent, name)
2929
else
30-
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre)
31-
if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
32-
# then we insert names -- can this be done more generically?
30+
println(io, " "^indent, isnothing(name) ? "" : "$name = ", pre)
31+
if obj isa Chain{<:NamedTuple} || obj isa NamedTuple
32+
# then we insert names -- can this be done more generically?
3333
for k in Base.keys(obj)
3434
_big_show(io, obj[k], indent+2, k)
3535
end
@@ -52,6 +52,20 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
5252
end
5353
end
5454

55+
for Fix in (:Fix1, :Fix2)
56+
pre = string(Fix, "(")
57+
@eval function _big_show(io::IO, obj::Base.$Fix, indent::Int=0, name=nothing)
58+
println(io, " "^indent, isnothing(name) ? "" : "$name = ", $pre)
59+
_big_show(io, obj.f, indent+2)
60+
_big_show(io, obj.x, indent+2)
61+
println(io, " "^indent, ")", ",")
62+
end
63+
end
64+
65+
_show_pre_post(obj) = string(nameof(typeof(obj)), "("), ")"
66+
_show_pre_post(::AbstractVector) = "[", "]"
67+
_show_pre_post(::NamedTuple) = "(;", ")"
68+
5569
_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:
5670

5771
# note the covariance of tuple, using <:T causes warning or error
@@ -88,7 +102,7 @@ end
88102

89103
function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
90104
_str = isnothing(name) ? "" : "$name = "
91-
str = _str * sprint(show, layer, context=io)
105+
str = _str * _layer_string(io, layer)
92106
print(io, " "^indent, str, indent==0 ? "" : ",")
93107
if !isempty(params(layer))
94108
print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
@@ -103,6 +117,15 @@ color=:light_black)
103117
indent==0 || println(io)
104118
end
105119

120+
_layer_string(io::IO, layer) = sprint(show, layer, context=io)
121+
# _layer_string(::IO, a::AbstractArray) = summary(layer) # sometimes too long e.g. CuArray
122+
function _layer_string(::IO, a::AbstractArray)
123+
full = string(typeof(a))
124+
comma = findfirst(',', full)
125+
short = isnothing(comma) ? full : full[1:comma] * "...}"
126+
Base.dims2string(size(a)) * " " * short
127+
end
128+
106129
function _big_finale(io::IO, m)
107130
ps = params(m)
108131
if length(ps) > 2
@@ -150,3 +173,43 @@ _any(f, x::Number) = f(x)
150173
# _any(f, x) = false
151174

152175
_all(f, xs) = !_any(!f, xs)
176+
177+
#=
178+
179+
julia> struct Tmp2; x; y; end; Flux.@functor Tmp2
180+
181+
# Before, notice Array(), NamedTuple(), and values
182+
183+
julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
184+
Chain(
185+
Tmp2(
186+
Array(
187+
Dense(2 => 3), # 9 parameters
188+
[0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters
189+
),
190+
NamedTuple(
191+
1:3, # 3 parameters
192+
Dense(3 => 4), # 16 parameters
193+
[0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters
194+
),
195+
),
196+
) # Total: 7 arrays, 43 parameters, 644 bytes.
197+
198+
# After, (; x=, y=, z=) and "3-element Array"
199+
200+
julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
201+
Chain(
202+
Tmp2(
203+
[
204+
Dense(2 => 3), # 9 parameters
205+
4×3 Adjoint, # 12 parameters
206+
],
207+
(;
208+
x = 3-element UnitRange, # 3 parameters
209+
y = Dense(3 => 4), # 16 parameters
210+
z = 3-element Array, # 3 parameters
211+
),
212+
),
213+
) # Total: 7 arrays, 43 parameters, 644 bytes.
214+
215+
=#

test/layers/show.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,13 @@ end
7171
# Functors@0.3 marks transposed matrices non-leaf, shouldn't affect printing:
7272
adjoint_chain = repr("text/plain", Chain([Dense([1 2; 3 4]')]))
7373
@test occursin("Dense(2 => 2)", adjoint_chain)
74-
@test occursin("Chain([", adjoint_chain)
74+
@test occursin("Chain(", adjoint_chain)
75+
@test occursin("[", adjoint_chain)
76+
77+
# New printing of arrays, and Fix1
78+
fix_chain = repr("text/plain", Chain(Base.Fix1(*, rand32(22,33)), softmax))
79+
@test occursin("Fix1(", fix_chain)
80+
@test occursin("22×33 Matrix{Float32}", fix_chain)
7581
end
7682

7783
# Bug when no children, https://github.com/FluxML/Flux.jl/issues/2208

0 commit comments

Comments
 (0)