Skip to content

Commit f2e3ac5

Browse files
authored
Shorten printing of Tangent and thunks (#564)
* don't print internal details of thunks * don't print super-long primal types for Tangent * add tests * version * fix Int32, comments * fix tests on 1.5
1 parent 2b650e1 commit f2e3ac5

File tree

5 files changed

+47
-4
lines changed

5 files changed

+47
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.15.1"
3+
version = "1.15.2"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/tangent_types/tangent.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,18 @@ Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h)
7575

7676
function Base.show(io::IO, tangent::Tangent{P}) where {P}
7777
print(io, "Tangent{")
78-
show(io, P)
78+
str = sprint(show, P, context = io)
79+
i = findfirst('{', str)
80+
if isnothing(i)
81+
print(io, str)
82+
else # for Tangent{T{A,B,C}}(stuff), print {A,B,C} in grey, and trim this part if longer than a line:
83+
print(io, str[1:prevind(str, i)])
84+
if length(str) < 80
85+
printstyled(io, str[i:end], color=:light_black)
86+
else
87+
printstyled(io, str[i:prevind(str, 80)], "...", color=:light_black)
88+
end
89+
end
7990
print(io, "}")
8091
if isempty(backing(tangent))
8192
print(io, "()") # so it doesn't show `NamedTuple()`

src/tangent_types/thunks.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,13 @@ end
196196

197197
function Base.show(io::IO, x::Thunk)
198198
print(io, "Thunk(")
199-
show(io, x.f)
199+
str = sprint(show, x.f, context = io) # often this name is like "ChainRules.var"#1398#1403"{Matrix{Float64}, Matrix{Float64}}"
200+
ind = findfirst("var\"#", str)
201+
if isnothing(ind) || length(str) < 80
202+
printstyled(io, str, color=:light_black)
203+
else
204+
printstyled(io, str[1:ind[5]], "...", color=:light_black)
205+
end
200206
print(io, ")")
201207
end
202208

@@ -223,7 +229,13 @@ unthunk(x::InplaceableThunk) = unthunk(x.val)
223229

224230
function Base.show(io::IO, x::InplaceableThunk)
225231
print(io, "InplaceableThunk(")
226-
show(io, x.add!)
232+
str = sprint(show, x.add!, context = io)
233+
ind = findfirst("var\"#", str) # look for auto-generated function names, often with huge types
234+
if isnothing(ind)
235+
printstyled(io, str, color=:light_black)
236+
else
237+
printstyled(io, str[1:ind[5]], "...", color=:light_black)
238+
end
227239
print(io, ", ")
228240
show(io, x.val)
229241
print(io, ")")

test/tangent_types/tangent.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,4 +380,12 @@ end
380380
c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1)
381381
@test nt + c == (; a=1, b=2.1)
382382
end
383+
384+
@testset "printing" begin
385+
t5 = Tuple(rand(3))
386+
nt3 = (x=t5, y=t5, z=nothing)
387+
tang = ProjectTo(nt3)(nt3) # moderately complicated Tangent
388+
@test contains(sprint(show, tang), "...}(x = Tangent") # gets shortened
389+
@test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole
390+
end
383391
end

test/tangent_types/thunks.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,16 @@
190190
@test scal!(2, 2.0, v, 1) == scal!(2, @thunk(2.0), v, 1)
191191
@test_throws MutateThunkException LAPACK.trsyl!('C', 'C', m, m, @thunk(m))
192192
end
193+
194+
@testset "printing" begin
195+
@test !contains(sprint(show, @thunk 1+1), "...") # short thunks not abbreviated
196+
th = let x = rand(100)
197+
@thunk x .+ x'
198+
end
199+
@test contains(sprint(show, th), "...") # but long ones are
200+
201+
@test contains(sprint(show, InplaceableThunk(mul!, th)), "mul!") # named functions left in InplaceableThunk
202+
str = sprint(show, InplaceableThunk(z -> z .+ ones(100), th))
203+
@test length(findall("...", str)) == 2 # now both halves shortened
204+
end
193205
end

0 commit comments

Comments
 (0)