Skip to content

Commit 55a3ddc

Browse files
authored
feat: error hints for Enzyme (#788)
* feat: error hints for Enzyme * Clearer message * Add tests * Simpler show * Unreleased * Move
1 parent 744dce4 commit 55a3ddc

File tree

6 files changed

+123
-4
lines changed

6 files changed

+123
-4
lines changed

DifferentiationInterface/CHANGELOG.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10-
### Changed
10+
### Added
1111

12-
- Allocate Enzyme shadow memory during preparation ([#782])
12+
- Error hints for Enzyme ([#788])
1313

1414
## [0.6.53] - 2025-05-07
1515

16+
### Changed
17+
18+
- Allocate Enzyme shadow memory during preparation ([#782])
19+
1620
[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...main
1721
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53
1822

23+
[#788]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/788
1924
[#782]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/782

DifferentiationInterface/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.53"
4+
version = "0.6.54"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -56,7 +56,7 @@ ADTypes = "1.13.0"
5656
ChainRulesCore = "1.23.0"
5757
DiffResults = "1.1.0"
5858
Diffractor = "=0.2.6"
59-
Enzyme = "0.13.17"
59+
Enzyme = "0.13.39"
6060
EnzymeCore = "0.8.8"
6161
ExplicitImports = "1.10.1"
6262
FastDifferentiation = "0.4.3"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ using EnzymeCore:
3030
Split,
3131
WithPrimal
3232
using Enzyme:
33+
Enzyme,
3334
autodiff,
3435
autodiff_thunk,
3536
create_shadows,
@@ -53,4 +54,6 @@ include("forward_twoarg.jl")
5354
include("reverse_onearg.jl")
5455
include("reverse_twoarg.jl")
5556

57+
include("init.jl")
58+
5659
end # module
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
const HINT_END = "\n\nThis hint appears because DifferentiationInterface and Enzyme are both loaded. It does not necessarily imply that Enzyme is being called through DifferentiationInterface.\n\n"
2+
3+
function HINT_START(option)
4+
return "\nIf you are using Enzyme by selecting the `AutoEnzyme` object from ADTypes, you may want to try setting the `$option` option as follows:"
5+
end
6+
7+
function __init__()
8+
# robust against internal changes
9+
condition = (
10+
isdefined(Enzyme, :Compiler) &&
11+
Enzyme.Compiler isa Module &&
12+
isdefined(Enzyme.Compiler, :EnzymeError) &&
13+
Enzyme.Compiler.EnzymeError isa DataType
14+
)
15+
condition || return nothing
16+
# see https://github.com/JuliaLang/julia/issues/58367 for why this isn't easier
17+
for n in names(Enzyme.Compiler; all=true)
18+
T = getfield(Enzyme.Compiler, n)
19+
if T isa DataType && T <: Enzyme.Compiler.EnzymeError
20+
# robust against internal changes
21+
Base.Experimental.register_error_hint(T) do io, exc
22+
if occursin("EnzymeMutabilityException", string(nameof(T)))
23+
printstyled(io, HINT_START("function_annotation"); bold=true)
24+
printstyled(
25+
io,
26+
"\n\n\tAutoEnzyme(; function_annotation=Enzyme.Duplicated)";
27+
color=:cyan,
28+
bold=true,
29+
)
30+
printstyled(io, HINT_END; italic=true)
31+
elseif occursin("EnzymeRuntimeActivityError", string(nameof(T)))
32+
printstyled(io, HINT_START("mode"); bold=true)
33+
printstyled(
34+
io,
35+
"\n\n\tAutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))\n\tAutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))";
36+
color=:cyan,
37+
bold=true,
38+
)
39+
printstyled(io, HINT_END; italic=true)
40+
end
41+
end
42+
end
43+
end
44+
end

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,56 @@ end
148148
f_nocontext, AutoEnzyme(; mode=Enzyme.Reverse), rand(10), ConstantOrCache(nothing)
149149
)
150150
end
151+
152+
@testset "Hints" begin
153+
@testset "MutabilityError" begin
154+
f = let
155+
cache = [0.0]
156+
x -> sum(copyto!(cache, x))
157+
end
158+
159+
e = nothing
160+
try
161+
gradient(f, AutoEnzyme(), [1.0])
162+
catch e
163+
end
164+
msg = sprint(showerror, e)
165+
@test occursin("AutoEnzyme", msg)
166+
@test occursin("function_annotation", msg)
167+
@test occursin("ADTypes", msg)
168+
@test occursin("DifferentiationInterface", msg)
169+
end
170+
171+
@testset "RuntimeActivityError" begin
172+
function g(active_var, constant_var, cond)
173+
if cond
174+
return active_var
175+
else
176+
return constant_var
177+
end
178+
end
179+
180+
function h(active_var, constant_var, cond)
181+
return [g(active_var, constant_var, cond), g(active_var, constant_var, cond)]
182+
end
183+
184+
e = nothing
185+
try
186+
pushforward(
187+
h,
188+
AutoEnzyme(; mode=Enzyme.Forward),
189+
[1.0],
190+
([1.0],),
191+
Constant([1.0]),
192+
Constant(true),
193+
)
194+
catch e
195+
end
196+
msg = sprint(showerror, e)
197+
@test occursin("AutoEnzyme", msg)
198+
@test occursin("mode", msg)
199+
@test occursin("set_runtime_activity", msg)
200+
@test occursin("ADTypes", msg)
201+
@test occursin("DifferentiationInterface", msg)
202+
end
203+
end
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using ADTypes
2+
using DifferentiationInterface
3+
import DifferentiationInterface as DI
4+
using Test
5+
6+
@testset "Missing backend" begin
7+
e = nothing
8+
try
9+
gradient(sum, AutoZygote(), [1.0])
10+
catch e
11+
end
12+
msg = sprint(showerror, e)
13+
@test occursin("import Zygote", msg)
14+
end

0 commit comments

Comments
 (0)