Skip to content

Commit df246e8

Browse files
authored
Document about being careful about pullbacks that call themselves (#435)
1 parent 8bc2961 commit df246e8

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

docs/src/writing_good_rules.md

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,86 @@ Note that the function does not have to be $\mathbb{R} \rightarrow \mathbb{R}$.
229229
In fact, any number of scalar arguments is supported, as is returning a tuple of scalars.
230230

231231
See docstrings for the comprehensive usage instructions.
232+
233+
234+
### Be careful about pullbacks closures calling other methods of themselves
235+
236+
Due to [JuliaLang/Julia#40990](https://github.com/JuliaLang/julia/issues/40990), a closure calling another (or the same) method of itself often comes out uninferable (and thus effectively type-unstable).
237+
This can be avoided by moving the pullback outside the function.
238+
For example:
239+
240+
```julia
241+
double_it(x::AbstractArray) = 2 .* x
242+
243+
function ChainRulesCore.rrule(::typeof(double_it), x)
244+
double_it_pullback(ȳ::AbstractArray) = (NoTangent(), 2 .* ȳ)
245+
double_it_pullback(ȳ::AbstractThunk) = double_it_pullback(unthunk(ȳ))
246+
return double_it(x), double_it_pullback
247+
end
248+
```
249+
Ends up infering a return type of `Any`
250+
```julia
251+
julia> _, pullback = rrule(double_it, [2.0, 3.0])
252+
([4.0, 6.0], var"#double_it_pullback#8"(Core.Box(var"#double_it_pullback#8"(#= circular reference @-2 =#))))
253+
254+
julia> @code_warntype pullback(@thunk([10.0, 10.0]))
255+
Variables
256+
#self#::var"#double_it_pullback#8"
257+
ȳ::Core.Const(Thunk(var"#9#10"()))
258+
double_it_pullback::Union{}
259+
260+
Body::Any
261+
1%1 = Core.getfield(#self#, :double_it_pullback)::Core.Box
262+
%2 = Core.isdefined(%1, :contents)::Bool
263+
└── goto #3 if not %2
264+
2 ─ goto #4
265+
3 ─ Core.NewvarNode(:(double_it_pullback))
266+
└── double_it_pullback
267+
4%7 = Core.getfield(%1, :contents)::Any
268+
%8 = Main.unthunk(ȳ)::Vector{Float64}
269+
%9 = (%7)(%8)::Any
270+
└── return %9
271+
```
272+
273+
This can be solved by moving the pullbacks outside the function so they are not closures, and thus to not run into this upstream issue.
274+
In this case that is fairly simple, since this example doesn't close over anything (if it did then would need a closure calling an outside function that calls itself).
275+
276+
```julia
277+
_double_it_pullback(ȳ::AbstractArray) = (NoTangent(), 2 .* ȳ)
278+
_double_it_pullback(ȳ::AbstractThunk) = _double_it_pullback(unthunk(ȳ))
279+
280+
function ChainRulesCore.rrule(::typeof(double_it), x)
281+
return double_it(x), _double_it_pullback
282+
end
283+
```
284+
This infers just fine:
285+
```julia
286+
julia> _, pullback = rrule(double_it, [2.0, 3.0])
287+
([4.0, 6.0], _double_it_pullback)
288+
289+
julia> @code_warntype pullback(@thunk([10.0, 10.0]))
290+
Variables
291+
#self#::Core.Const(_double_it_pullback)
292+
ȳ::Core.Const(Thunk(var"#7#8"()))
293+
294+
Body::Tuple{NoTangent, Vector{Float64}}
295+
1%1 = Main.unthunk(ȳ)::Vector{Float64}
296+
%2 = Main._double_it_pullback(%1)::Core.PartialStruct(Tuple{NoTangent, Vector{Float64}}, Any[Core.Const(NoTangent()), Vector{Float64}])
297+
└── return %2
298+
```
299+
300+
Though in this particular case, it can also be solved by taking advantage of duck-typing and just writing one method.
301+
Thus avoiding the call that confuses the compiler.
302+
`Thunk`s duck-type as the type they wrap in most cases: including broadcast multiplication.
303+
304+
```julia
305+
function ChainRulesCore.rrule(::typeof(double_it), x)
306+
double_it_pullback(ȳ) = (NoTangent(), 2 .* ȳ)
307+
return double_it(x), double_it_pullback
308+
end
309+
```
310+
This infers perfectly.
311+
232312
## Write tests
233313
234314
[ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl)

0 commit comments

Comments
 (0)