Skip to content

Commit 99cf2ee

Browse files
authored
Merge pull request #108 from JuliaDiff/atone-for-our-sins
removes `@pure` and some miscelanious cleanups
2 parents 342a2cd + eb7f7f6 commit 99cf2ee

File tree

5 files changed

+33
-35
lines changed

5 files changed

+33
-35
lines changed

src/extra_rules.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function (::∂⃖{N})(f::typeof(*), args...) where {N}
7979
end
8080
return z
8181
else
82-
∂⃖p = ∂⃖{minus1(N)}()
82+
∂⃖p = ∂⃖{N-1}()
8383
@destruct z, z̄ = ∂⃖p(rrule_times, f, args...)
8484
if z === nothing
8585
return ∂⃖recurse{N}()(f, args...)
@@ -130,15 +130,15 @@ end
130130
struct NonDiffEven{N, O, P}; end
131131
struct NonDiffOdd{N, O, P}; end
132132

133-
(::NonDiffOdd{N, O, P})(Δ) where {N, O, P} = (ntuple(_->ZeroTangent(), N), NonDiffEven{N, plus1(O), P}())
134-
(::NonDiffEven{N, O, P})(Δ...) where {N, O, P} = (ZeroTangent(), NonDiffOdd{N, plus1(O), P}())
133+
(::NonDiffOdd{N, O, P})(Δ) where {N, O, P} = (ntuple(_->ZeroTangent(), N), NonDiffEven{N, O+1, P}())
134+
(::NonDiffEven{N, O, P})(Δ...) where {N, O, P} = (ZeroTangent(), NonDiffOdd{N, O+1, P}())
135135
(::NonDiffOdd{N, O, O})(Δ) where {N, O} = ntuple(_->ZeroTangent(), N)
136136

137137
# This should not happen
138138
(::NonDiffEven{N, O, O})(Δ...) where {N, O} = error()
139139

140-
@Base.pure function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.apply_type), head, args...)
141-
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}()
140+
@Base.assume_effects :total function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.apply_type), head, args...)
141+
Core.apply_type(head, args...), NonDiffOdd{length(args)+2, 1, 1}()
142142
end
143143

144144
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.tuple), args...)

src/interface.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,9 @@ Base.show(io::IO, f::PrimeDerivativeBack{N}) where {N} = print(io, f.f, "'"^N)
143143

144144
# This improves performance for nested derivatives by short cutting some
145145
# recursion into the PrimeDerivative constructor
146-
@Base.pure minus1(N) = N - 1
147-
@Base.pure plus1(N) = N + 1
148-
lower_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{minus1(N),T}(getfield(f, :f))
146+
lower_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{N-1,T}(getfield(f, :f))
149147
lower_pd(f::PrimeDerivativeBack{1}) = getfield(f, :f)
150-
raise_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{plus1(N),T}(getfield(f, :f))
148+
raise_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{N+1,T}(getfield(f, :f))
151149

152150
ChainRulesCore.rrule(::typeof(lower_pd), f) = lower_pd(f), Δ->(ZeroTangent(), Δ)
153151
ChainRulesCore.rrule(::typeof(raise_pd), f) = raise_pd(f), Δ->(ZeroTangent(), Δ)
@@ -170,8 +168,8 @@ end
170168
PrimeDerivativeFwd(f) = PrimeDerivativeFwd{1, typeof(f)}(f)
171169
PrimeDerivativeFwd(f::PrimeDerivativeFwd{N, T}) where {N, T} = raise_pd(f)
172170

173-
lower_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = (error(); PrimeDerivativeFwd{minus1(N),T}(getfield(f, :f)))
174-
raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{plus1(N),T}(getfield(f, :f))
171+
lower_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = (error(); PrimeDerivativeFwd{N-1,T}(getfield(f, :f)))
172+
raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{N+1,T}(getfield(f, :f))
175173

176174
(f::PrimeDerivativeFwd{0})(x) = getfield(f, :f)(x)
177175

src/stage1/forward.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ primal(z::ZeroTangent) = ZeroTangent()
2020
first_partial(x) = partial(x, 1)
2121

2222
shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} =
23-
UniformBundle{minus1(N), <:Any}(UniformBundle{1, B}(b.primal, b.tangent.val),
23+
UniformBundle{N-1, <:Any}(UniformBundle{1, B}(b.primal, b.tangent.val),
2424
UniformBundle{1, U}(b.tangent.val, b.tangent.val))
2525

2626
function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
@@ -30,7 +30,7 @@ function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
3030
end
3131
ExplicitTangentBundle{N-1}(
3232
ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)),
33-
ntuple(_sdown, 2^(N-1)-1))
33+
ntuple(_sdown, 1<<(N-1)-1))
3434
end
3535

3636
function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
@@ -86,7 +86,7 @@ function shuffle_up(r::CompositeBundle{N}) where {N}
8686
else
8787
return TangentBundle{N+1}(r.tup[1].primal,
8888
(r.tup[1].tangent.partials..., primal(b),
89-
ntuple(i->partial(b,i), 2^(N+1)-1)...))
89+
ntuple(i->partial(b,i), 1<<(N+1)-1)...))
9090
end
9191
end
9292

@@ -131,10 +131,10 @@ function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
131131
end
132132

133133
function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
134-
∂☆p = ∂☆{minus1(N)}()
134+
∂☆p = ∂☆{N-1}()
135135
downargs = map(shuffle_down, args)
136-
tupargs = ∂vararg{minus1(N)}()(map(first_partial, downargs)...)
137-
∂☆p(ZeroBundle{minus1(N)}(frule), #= ZeroBundle{minus1(N)}(DiffractorRuleConfig()), =# tupargs, map(primal, downargs)...)
136+
tupargs = ∂vararg{N-1}()(map(first_partial, downargs)...)
137+
∂☆p(ZeroBundle{N-1}(frule), #= ZeroBundle{N-1}(DiffractorRuleConfig()), =# tupargs, map(primal, downargs)...)
138138
end
139139

140140
function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}

src/stage1/generated.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,12 @@ struct ∂⃖weaveInnerOdd{N, O}; b̄; end
100100
end
101101
@Base.constprop :aggressive function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O}
102102
@destruct c, c̄ = w....)
103-
return (c̄, c), ∂⃖weaveInnerEven{plus1(N), O}()
103+
return (c̄, c), ∂⃖weaveInnerEven{N+1, O}()
104104
end
105105
struct ∂⃖weaveInnerEven{N, O}; end
106106
@Base.constprop :aggressive function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O}
107107
@destruct y, ȳ = Δ′(x...)
108-
return y, ∂⃖weaveInnerOdd{plus1(N), O}(ȳ)
108+
return y, ∂⃖weaveInnerOdd{N+1, O}(ȳ)
109109
end
110110

111111
struct ∂⃖weaveOuterOdd{N, O}; end
@@ -114,15 +114,15 @@ struct ∂⃖weaveOuterOdd{N, O}; end
114114
end
115115
@Base.constprop :aggressive function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O}
116116
@destruct α, ᾱ = Δ′′′(Δ′′)
117-
return (NoTangent(), α...), ∂⃖weaveOuterEven{plus1(N), O}(ᾱ)
117+
return (NoTangent(), α...), ∂⃖weaveOuterEven{N+1, O}(ᾱ)
118118
end
119119
struct ∂⃖weaveOuterEven{N, O}; ᾱ end
120120
@Base.constprop :aggressive function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O}
121-
return w.ᾱ(Base.tail(Δ⁴)...), ∂⃖weaveOuterOdd{plus1(N), O}()
121+
return w.ᾱ(Base.tail(Δ⁴)...), ∂⃖weaveOuterOdd{N+1, O}()
122122
end
123123

124124
function (::∂⃖{N})(::∂⃖{1}, args...) where {N}
125-
@destruct (a, ā) = ∂⃖{plus1(N)}()(args...)
125+
@destruct (a, ā) = ∂⃖{N+1}()(args...)
126126
let O = c_order(N)
127127
(a, Protected{N}(@opaque Δ->begin
128128
(b, b̄) = (Δ)
@@ -187,10 +187,10 @@ end
187187
(::∂⃖rruleD{N, N})(Δ...) where {N} = error("Should not be reached")
188188

189189
# ∂⃖rrule
190-
@Base.pure term_depth(N) = 2^(N-2)
190+
term_depth(N) = 1<<(N-2)
191191
function (::∂⃖rrule{N})(z, z̄) where {N}
192192
@destruct (y, ȳ) = z
193-
y, ∂⃖rruleA{term_depth(N), 1}(∂⃖{minus1(N)}(), ȳ, z̄)
193+
y, ∂⃖rruleA{term_depth(N), 1}(∂⃖{N-1}(), ȳ, z̄)
194194
end
195195

196196
function (::∂⃖{N})(f::Core.IntrinsicFunction, args...) where {N}
@@ -216,7 +216,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
216216
end
217217
return z
218218
else
219-
∂⃖p = ∂⃖{minus1(N)}()
219+
∂⃖p = ∂⃖{N-1}()
220220
@destruct z, z̄ = ∂⃖p(rrule, f, args...)
221221
if z === nothing
222222
return ∂⃖recurse{N}()(f, args...)
@@ -230,7 +230,7 @@ function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) wher
230230
Tuple{Any, Any}(∂⃖{1}()(f, args...))
231231
end
232232

233-
@Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...)
233+
@Base.assume_effects :total function (::∂⃖{1})(::typeof(Core.apply_type), head, args...)
234234
return rrule(Core.apply_type, head, args...)
235235
end
236236

@@ -283,8 +283,8 @@ struct EvenOddEven{O, P, F, G}; f::F; g::G; end
283283
EvenOddEven{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddEven{O, P, F, G}(f, g)
284284
struct EvenOddOdd{O, P, F, G}; f::F; g::G; end
285285
EvenOddOdd{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddOdd{O, P, F, G}(f, g)
286-
@Base.constprop :aggressive (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g))
287-
@Base.constprop :aggressive (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g))
286+
@Base.constprop :aggressive (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{O+1, P, F, G}(o.f, o.g))
287+
@Base.constprop :aggressive (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g...), EvenOddOdd{O+1, P, F, G}(e.f, e.g))
288288
@Base.constprop :aggressive (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ)
289289

290290

@@ -362,11 +362,11 @@ struct ApplyOdd{O, P}; u; ∂⃖f; end
362362
struct ApplyEven{O, P}; u; ∂⃖∂⃖f; end
363363
@Base.constprop :aggressive function (a::ApplyOdd{O, P})(Δ) where {O, P}
364364
r, ∂⃖∂⃖f = a.∂⃖f(Δ)
365-
(a.u(r), ApplyEven{plus1(O), P}(a.u, ∂⃖∂⃖f))
365+
(a.u(r), ApplyEven{O+1, P}(a.u, ∂⃖∂⃖f))
366366
end
367367
@Base.constprop :aggressive function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P}
368368
r, ∂⃖∂⃖∂⃖f = Core._apply_iterate(iterate, a.∂⃖∂⃖f, (ff,), args...)
369-
(r, ApplyOdd{plus1(O), P}(a.u, ∂⃖∂⃖∂⃖f))
369+
(r, ApplyOdd{O+1, P}(a.u, ∂⃖∂⃖∂⃖f))
370370
end
371371
@Base.constprop :aggressive function (a::ApplyOdd{O, O})(Δ) where {O}
372372
r = a.∂⃖f(Δ)
@@ -380,10 +380,10 @@ function (this::∂⃖{N})(::typeof(Core._apply_iterate), iterate, f, args::Unio
380380
end
381381

382382

383-
@Base.pure c_order(N::Int) = 2^N - 1
383+
c_order(N::Int) = 1<<N - 1
384384

385-
@Base.pure function (::∂⃖{N})(::typeof(Core.apply_type), head, args...) where {N}
386-
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, c_order(N)}()
385+
@Base.assume_effects :total function (::∂⃖{N})(::typeof(Core.apply_type), head, args...) where {N}
386+
Core.apply_type(head, args...), NonDiffOdd{length(args)+2, 1, c_order(N)}()
387387
end
388388

389389
@Base.constprop :aggressive lifted_getfield(x, s) = getfield(x, s)

src/stage1/mixed.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ end
1414
function (x::∂⃖composeOdd)(Δ)
1515
b, ∂b = x.b(Δ)
1616
a, ∂a = x.a(b[end])
17-
a, ∂⃖composeEven{N, plus1(N)}(∂a, ∂b)
17+
a, ∂⃖composeEven{N, N+1}(∂a, ∂b)
1818
end
1919

2020
function (x::∂⃖composeEven)(args...)
2121
a, ∂a = x.a(args...)
2222
b, ∂b = x.b(a)
23-
b, ∂⃖composeOdd{N, plus1(N)}(∂a, ∂b)
23+
b, ∂⃖composeOdd{N, N+1}(∂a, ∂b)
2424
end
2525

2626
function (x::∂⃖composeOdd{N,N})(Δ) where {N}

0 commit comments

Comments
 (0)