Skip to content

Commit 7290a2a

Browse files
committed
Add Debug Mode
1 parent a4e078a commit 7290a2a

File tree

5 files changed

+49
-6
lines changed

5 files changed

+49
-6
lines changed

docs/src/debug_mode.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Debug Mode
2+
3+
ChainRules supports a [`debug_mode`](@ref) which you can use while writing new rules.
4+
It provides better error messages.
5+
If you are developing some new rules, and you get a weird error message,
6+
it is worth enabling debug mode.
7+
8+
There is some overhead to having it enabled, so it is disabled by default.
9+
10+
To enable redefine the `debug_mode()` function to return `true`.
11+
```julia
12+
ChainRulesCore.debug_mode() = true
13+
```

src/ChainRulesCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero
88
export NO_FIELDS
99

1010
include("compat.jl")
11+
include("debug_mode.jl")
1112

1213
include("differentials/abstract_differential.jl")
1314
include("differentials/abstract_zero.jl")

src/debug_mode.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
debug_mode()
3+
4+
Determines if ChainRulesCore is in `debug_mode`.
5+
Defaults to `false`, but if the user redefines it to return `true` then extra
6+
information will be shown when errors occur.
7+
8+
Enable via:
9+
```
10+
ChainRulesCore.debug_mode() = true
11+
```
12+
"""
13+
debug_mode() = false

src/differential_arithmetic.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,15 @@ function Base.:+(a::Composite{P}, b::Composite{P}) where P
8686
return Composite{P, typeof(data)}(data)
8787
end
8888
function Base.:+(a::P, d::Composite{P}) where P
89-
try
90-
return construct(P, elementwise_add(backing(a), backing(d)))
91-
catch err
92-
throw(PrimalAdditionFailedException(a, d, err))
89+
net_backing = elementwise_add(backing(a), backing(d))
90+
if debug_mode()
91+
try
92+
return construct(P, net_backing)
93+
catch err
94+
throw(PrimalAdditionFailedException(a, d, err))
95+
end
96+
else
97+
return construct(P, net_backing)
9398
end
9499
end
95100
Base.:+(a::Composite{P}, b::P) where P = b + a

test/differentials/composite.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,19 @@ end
128128
@testset "+ with Primals, with inner constructor" begin
129129
value = StructWithInvariant(10.0)
130130
diff = Composite{StructWithInvariant}(x=2.0, x2=6.0)
131-
@test_throws ChainRulesCore.PrimalAdditionFailedException (value + diff)
132-
@test_throws ChainRulesCore.PrimalAdditionFailedException (diff + value)
131+
132+
@testset "with and without debug mode" begin
133+
@assert ChainRulesCore.debug_mode() == false
134+
@test_throws MethodError (value + diff)
135+
@test_throws MethodError (diff + value)
136+
137+
ChainRulesCore.debug_mode() = true # enable debug mode
138+
@test_throws ChainRulesCore.PrimalAdditionFailedException (value + diff)
139+
@test_throws ChainRulesCore.PrimalAdditionFailedException (diff + value)
140+
ChainRulesCore.debug_mode() = false # disable it again
141+
end
142+
143+
133144

134145
# Now we define constuction for ChainRulesCore.jl's purposes:
135146
# It is going to determine the root quanity of the invarient

0 commit comments

Comments
 (0)