Skip to content

Commit 8143b40

Browse files
authored
Improve display of rules (#34)
1 parent 52fb1e0 commit 8143b40

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/rule_types.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing)
160160

161161
(rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...)
162162

163+
Base.show(io::IO, rule::Rule{<:Any, Nothing}) = print(io, "Rule($(rule.f))")
164+
Base.show(io::IO, rule::Rule) = print(io, "Rule($(rule.f), $(rule.u))")
165+
163166
# Specialized accumulation
164167
# TODO: Does this need to be overdubbed in the rule context?
165168
accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...)
@@ -214,4 +217,3 @@ function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
214217
return WirtingerRule(primal, conjugate)
215218
end
216219
end
217-

test/rule_types.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@
1212
@test_throws BoundsError rule[2]
1313
end
1414

15+
@testset "Rule" begin
16+
@testset "show" begin
17+
@test occursin(r"^Rule\(.*foo.*\)$", repr(Rule(function foo() 1 end)))
18+
@test occursin(r"^Rule\(.*identity.*\)$", repr(Rule(identity)))
19+
20+
@test occursin(r"^Rule\(.*identity.*\,.*\+.*\)$", repr(Rule(identity, +)))
21+
end
22+
end
23+
1524
@testset "WirtingerRule" begin
1625
myabs2(x) = abs2(x)
1726

0 commit comments

Comments
 (0)