@@ -65,10 +65,20 @@ function _big_finale(io::IO, m)
65
65
if length (ps) > 2
66
66
pars = underscorise (sum (length, ps))
67
67
bytes = Base. format_bytes (Base. summarysize (m))
68
- printstyled (io, " " ^ 19 , " # Total: " , length (ps), " arrays, " , pars, " parameters, " , bytes; color= :light_black )
68
+ noncnt = _childarray_sum (_-> 1 , m) - length (ps)
69
+ if noncnt > 0
70
+ nonparam = underscorise (_childarray_sum (length, m) - sum (length, ps))
71
+ printstyled (io, " " ^ 19 , " # Total: " , length (ps), " trainable arrays, " , pars, " parameters,\n " ; color= :light_black )
72
+ printstyled (io, " " ^ 20 , " # plus " , noncnt, " non-trainable, " , nonparam, " parameters, total size " , bytes; color= :light_black )
73
+ else
74
+ printstyled (io, " " ^ 19 , " # Total: " , length (ps), " arrays, " , pars, " parameters, " , bytes; color= :light_black )
75
+ end
69
76
end
70
77
end
71
78
79
+ _childarray_sum (f, x:: AbstractArray ) = f (x) # count includes non-trainable arrays excluded from params
80
+ _childarray_sum (f, x) = isleaf (x) ? 0 : sum (y -> _childarray_sum (f, y), Functors. children (x))
81
+
72
82
# utility functions
73
83
74
84
underscorise (n:: Integer ) =
0 commit comments