Skip to content

Commit fb1d4ec

Browse files
authored
Accumulate NamedTuple + Tangent (#88)
* accumulate NamedTuple + Tangent * fixup * don't test on 1.8
1 parent 9a8a788 commit fb1d4ec

File tree

4 files changed

+24
-13
lines changed

4 files changed

+24
-13
lines changed

.github/workflows/CI.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ jobs:
1717
matrix:
1818
version:
1919
- '1.7' # Lowest claimed support in Project.toml
20-
- '1' # Latest Release
20+
# - '1' # Latest Release # Testing on 1.8 gives this message:
21+
# ┌ Warning: ir verification broken. Either use 1.9 or 1.7
22+
# └ @ Diffractor ~/work/Diffractor.jl/Diffractor.jl/src/stage1/recurse.jl:889
2123
- 'nightly'
2224
os:
2325
- ubuntu-latest

src/extra_rules.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,6 @@ end
266266
# Reverse rule for calling `InplaceableThunk(add!!, val)` under `DiffractorRuleConfig`.
# The primal result is `val` itself. The pullback returns a 3-tuple whose first two
# entries are `NoTangent()` and whose last entry is the incoming cotangent `Δ`
# (presumably the slots for the type, `add!!`, and `val` — confirm against ChainRulesCore).
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk}, add!!, val)
267267
val, Δ->(NoTangent(), NoTangent(), Δ)
268268
end
269+
270+
# `real` of a `ZeroTangent` returns the same object unchanged.
# NOTE(review): "CRC" in the TODO presumably means ChainRulesCore — confirm.
Base.real(z::ZeroTangent) = z # TODO should be in CRC
271+
# `real` of a `NoTangent` returns the same object unchanged.
Base.real(z::NoTangent) = z

src/runtime.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,25 @@ struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}}
55
# Accumulate two tuples elementwise by applying `accum` to corresponding entries.
@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)
66
# Accumulate two NamedTuples field by field. Because this is `@generated`, the body
# works with the argument *types* and returns the expression compiled for them.
@Base.constprop :aggressive @generated function accum(x::NamedTuple, y::NamedTuple)
77
# Every field name that occurs in either input, without duplicates.
fnames = union(fieldnames(x), fieldnames(y))
8+
isempty(fnames) && return :((;)) # code below makes () instead
89
# A field that is missing from one side is represented by `ZeroTangent()` on that side.
gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent())
910
grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent())
1011
# Emit `(name = accum(<x part>, <y part>), ...)` with one entry per field name.
Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...)
1112
end
1213
# With three or more arguments, fold from the left: combine the first two, then
# recurse with that result and the remaining arguments.
@Base.constprop :aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...)
13-
@Base.constprop :aggressive accum(a::NoTangent, b) = b
14-
@Base.constprop :aggressive accum(a, b::NoTangent) = a
15-
@Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent()
14+
# Accumulation with an `AbstractZero` on one side returns the other argument unchanged.
@Base.constprop :aggressive accum(a::AbstractZero, b) = b
15+
@Base.constprop :aggressive accum(a, b::AbstractZero) = a
16+
# Both sides zero: the result is always `NoTangent()`, whichever `AbstractZero`
# subtypes were passed.
# NOTE(review): if `ZeroTangent <: AbstractZero`, two `ZeroTangent()`s also become
# `NoTangent()` here — confirm that is intended.
@Base.constprop :aggressive accum(a::AbstractZero, b::AbstractZero) = NoTangent()
17+
18+
using ChainRulesCore: Tangent, backing
19+
20+
# Accumulate a `Tangent{T}` with a plain NamedTuple gradient: take `backing(x)`,
# accumulate it with `y`, and pass the result to `_tangent` together with `T`.
function accum(x::Tangent{T}, y::NamedTuple) where T
21+
# @warn "gradient is both a Tangent and a NamedTuple" x y
22+
_tangent(T, accum(backing(x), y))
23+
end
24+
# Mirror of the (Tangent, NamedTuple) case: swap the arguments and delegate.
accum(x::NamedTuple, y::Tangent) = accum(y, x)
25+
# This resolves a method ambiguity, and also avoids Tangent{ZeroTangent}(), which `+` does not:
26+
# Accumulate the two backings and hand the result to `_tangent` with `x`'s type
# parameter `T`; the type parameter of `y` is not used.
accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing(y)))
27+
28+
# Wrap `z` in a `Tangent` whose type parameters are `T` and `typeof(z)`.
_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
29+
# An empty NamedTuple has no fields, so return `NoTangent()` instead of a `Tangent`.
_tangent(::Type, ::NamedTuple{()}) = NoTangent()

test/runtests.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,7 @@ end
162162
# Make sure that there's no infinite recursion in kwarg calls
163163
# `g_kw` accepts `x` only as a keyword argument (default 1.0) and returns `sin(x)`.
g_kw(;x=1.0) = sin(x)
164164
# `f_kw` forwards its positional argument to `g_kw` as the keyword `x`.
f_kw(x) = g_kw(;x)
165-
@test bwd(f_kw)(1.0) == bwd(sin)(1.0) broken=true
166-
#=
167-
MethodError: no method matching +(::Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}, ::Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}})
168-
...
169-
[2] elementwise_add(a::NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}, b::NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}})
170-
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/tangent.jl:287
171-
[3] +(a::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}}, b::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}}})
172-
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_arithmetic.jl:130
173-
=#
165+
@test bwd(f_kw)(1.0) == bwd(sin)(1.0)
174166

175167
function f_crit_edge(a, b, c, x)
176168
# A function with two critical edges. This used to trigger an issue where

0 commit comments

Comments
 (0)