Skip to content

Commit dc5eca2

Browse files
committed
count non-trainable parameters
1 parent 8f11ab8 commit dc5eca2

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

src/layers/show.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,20 @@ function _big_finale(io::IO, m)
6565
if length(ps) > 2
6666
pars = underscorise(sum(length, ps))
6767
bytes = Base.format_bytes(Base.summarysize(m))
68-
printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black)
68+
noncnt = _childarray_sum(_->1, m) - length(ps)
69+
if noncnt > 0
70+
nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps))
71+
printstyled(io, " "^19, "# Total: ", length(ps), " trainable arrays, ", pars, " parameters,\n"; color=:light_black)
72+
printstyled(io, " "^20, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, total size ", bytes; color=:light_black)
73+
else
74+
printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black)
75+
end
6976
end
7077
end
7178

79+
_childarray_sum(f, x::AbstractArray) = f(x) # count includes non-trainable arrays excluded from params
80+
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))
81+
7282
# utility functions
7383

7484
underscorise(n::Integer) =

0 commit comments

Comments
 (0)