@@ -20,16 +20,16 @@ function _macro_big_show(ex)
20
20
end
21
21
22
22
function _big_show (io:: IO , obj, indent:: Int = 0 , name= nothing )
23
- pre, post = obj isa Chain{ <: AbstractVector } ? ( " ([ " , " ]) " ) : ( " ( " , " ) " )
23
+ pre, post = _show_pre_post (obj )
24
24
children = _show_children (obj)
25
25
if all (_show_leaflike, children)
26
26
# This check may not be useful anymore: it tries to infer when to stop the recursion by looking for grandkids,
27
27
# but once all layers use @layer, they stop the recursion by defining a method for _big_show.
28
28
_layer_show (io, obj, indent, name)
29
29
else
30
- println (io, " " ^ indent, isnothing (name) ? " " : " $name = " , nameof ( typeof (obj)), pre)
31
- if obj isa Chain{<: NamedTuple } && children == getfield ( obj, :layers )
32
- # then we insert names -- can this be done more generically?
30
+ println (io, " " ^ indent, isnothing (name) ? " " : " $name = " , pre)
31
+ if obj isa Chain{<: NamedTuple } || obj isa NamedTuple
32
+ # then we insert names -- can this be done more generically?
33
33
for k in Base. keys (obj)
34
34
_big_show (io, obj[k], indent+ 2 , k)
35
35
end
@@ -52,6 +52,20 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
52
52
end
53
53
end
54
54
55
+ for Fix in (:Fix1 , :Fix2 )
56
+ pre = string (Fix, " (" )
57
+ @eval function _big_show (io:: IO , obj:: Base. $ Fix, indent:: Int = 0 , name= nothing )
58
+ println (io, " " ^ indent, isnothing (name) ? " " : " $name = " , $ pre)
59
+ _big_show (io, obj. f, indent+ 2 )
60
+ _big_show (io, obj. x, indent+ 2 )
61
+ println (io, " " ^ indent, " )" , " ," )
62
+ end
63
+ end
64
+
65
+ _show_pre_post (obj) = string (nameof (typeof (obj)), " (" ), " )"
66
+ _show_pre_post (:: AbstractVector ) = " [" , " ]"
67
+ _show_pre_post (:: NamedTuple ) = " (;" , " )"
68
+
55
69
_show_leaflike (x) = isleaf (x) # mostly follow Functors, except for:
56
70
57
71
# note the covariance of tuple, using <:T causes warning or error
88
102
89
103
function _layer_show (io:: IO , layer, indent:: Int = 0 , name= nothing )
90
104
_str = isnothing (name) ? " " : " $name = "
91
- str = _str * sprint (show , layer, context = io )
105
+ str = _str * _layer_string (io , layer)
92
106
print (io, " " ^ indent, str, indent== 0 ? " " : " ," )
93
107
if ! isempty (params (layer))
94
108
print (io, " " ^ max (2 , (indent== 0 ? 20 : 39 ) - indent - length (str)))
@@ -103,6 +117,15 @@ color=:light_black)
103
117
indent== 0 || println (io)
104
118
end
105
119
120
+ _layer_string (io:: IO , layer) = sprint (show, layer, context= io)
121
+ # _layer_string(::IO, a::AbstractArray) = summary(layer) # sometimes too long e.g. CuArray
122
+ function _layer_string (:: IO , a:: AbstractArray )
123
+ full = string (typeof (a))
124
+ comma = findfirst (' ,' , full)
125
+ short = isnothing (comma) ? full : full[1 : comma] * " ...}"
126
+ Base. dims2string (size (a)) * " " * short
127
+ end
128
+
106
129
function _big_finale (io:: IO , m)
107
130
ps = params (m)
108
131
if length (ps) > 2
@@ -150,3 +173,43 @@ _any(f, x::Number) = f(x)
150
173
# _any(f, x) = false
151
174
152
175
_all (f, xs) = ! _any (! f, xs)
176
+
177
+ #=
178
+
179
+ julia> struct Tmp2; x; y; end; Flux.@functor Tmp2
180
+
181
+ # Before, notice Array(), NamedTuple(), and values
182
+
183
+ julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
184
+ Chain(
185
+ Tmp2(
186
+ Array(
187
+ Dense(2 => 3), # 9 parameters
188
+ [0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters
189
+ ),
190
+ NamedTuple(
191
+ 1:3, # 3 parameters
192
+ Dense(3 => 4), # 16 parameters
193
+ [0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters
194
+ ),
195
+ ),
196
+ ) # Total: 7 arrays, 43 parameters, 644 bytes.
197
+
198
+ # After, (; x=, y=, z=) and "3-element Array"
199
+
200
+ julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
201
+ Chain(
202
+ Tmp2(
203
+ [
204
+ Dense(2 => 3), # 9 parameters
205
+ 4×3 Adjoint, # 12 parameters
206
+ ],
207
+ (;
208
+ x = 3-element UnitRange, # 3 parameters
209
+ y = Dense(3 => 4), # 16 parameters
210
+ z = 3-element Array, # 3 parameters
211
+ ),
212
+ ),
213
+ ) # Total: 7 arrays, 43 parameters, 644 bytes.
214
+
215
+ =#
0 commit comments