Skip to content

Commit 7a72d84

Browse files
committed
make rules render better
1 parent 0fabda9 commit 7a72d84

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

src/ChainRulesCore.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
33

44
export AbstractRule, Rule, frule, rrule
55
export @scalar_rule, @thunk
6-
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
6+
export extern, cast, store!
7+
export Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
78
export NO_FIELDS_RULE, ZERO_RULE
89

910
include("differentials.jl")

src/rules.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing)
160160

161161
(rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...)
162162

163+
Base.show(io::IO, rule::Rule{<:Any, Nothing}) = print(io, "Rule($(rule.f))")
164+
Base.show(io::IO, rule::Rule) = print(io, "Rule($(rule.f), $(rule.u))")
165+
163166
# Specialized accumulation
164167
# TODO: Does this need to be overdubbed in the rule context?
165168
accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...)
@@ -173,7 +176,7 @@ The most notable use for this is for the reverse-mode derivative with respect to
173176
function itself, when that function is not a closure.
174177
The rule returns an empty `NamedTuple` for all inputs.
175178
"""
176-
const NO_FIELDS_RULE = Rule((args...)->NamedTuple())
179+
const NO_FIELDS_RULE = Rule(function no_fields(args...) NamedTuple() end)
177180

178181
"""
179182
ZERO_RULE
@@ -182,7 +185,7 @@ This is a rule that returns `Zero()` regardless of input.
182185
The most notable use for this is for the forward-mode derivative with respect to the
183186
function itself, when that function is not a closure.
184187
"""
185-
const ZERO_RULE = Rule((args...)->Zero())
188+
const ZERO_RULE = Rule(function always_zero(args...) Zero() end)
186189

187190

188191

0 commit comments

Comments
 (0)