Skip to content

Commit 0a11f37

Browse files
committed
removes @pure and some miscelanious cleanups
1 parent 5b3a846 commit 0a11f37

File tree

6 files changed

+36
-38
lines changed

6 files changed

+36
-38
lines changed

Manifest.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
2727

2828
[[deps.ChainRules]]
2929
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"]
30-
git-tree-sha1 = "c46adabdd0348f0ee8de91142cfc4a72a613ac0a"
30+
git-tree-sha1 = "fdde4d8a31cf82b1d136cf6cb53924e8744a832b"
3131
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
32-
version = "1.46.1"
32+
version = "1.47.0"
3333

3434
[[deps.ChainRulesCore]]
3535
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
@@ -265,9 +265,9 @@ version = "1.10.0"
265265

266266
[[deps.StaticArrays]]
267267
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
268-
git-tree-sha1 = "129703d62117c374c4f2db6d13a027741c46eafd"
268+
git-tree-sha1 = "cee507162ecbb677450f20058ca83bd559b6b752"
269269
uuid = "90137ffa-7385-5640-81b9-e52037218182"
270-
version = "1.5.13"
270+
version = "1.5.14"
271271

272272
[[deps.StaticArraysCore]]
273273
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"

src/extra_rules.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function (::∂⃖{N})(f::typeof(*), args...) where {N}
7979
end
8080
return z
8181
else
82-
∂⃖p = ∂⃖{minus1(N)}()
82+
∂⃖p = ∂⃖{N-1}()
8383
@destruct z, z̄ = ∂⃖p(rrule_times, f, args...)
8484
if z === nothing
8585
return ∂⃖recurse{N}()(f, args...)
@@ -130,15 +130,15 @@ end
130130
struct NonDiffEven{N, O, P}; end
131131
struct NonDiffOdd{N, O, P}; end
132132

133-
(::NonDiffOdd{N, O, P})(Δ) where {N, O, P} = (ntuple(_->ZeroTangent(), N), NonDiffEven{N, plus1(O), P}())
134-
(::NonDiffEven{N, O, P})(Δ...) where {N, O, P} = (ZeroTangent(), NonDiffOdd{N, plus1(O), P}())
133+
(::NonDiffOdd{N, O, P})(Δ) where {N, O, P} = (ntuple(_->ZeroTangent(), N), NonDiffEven{N, O+1, P}())
134+
(::NonDiffEven{N, O, P})(Δ...) where {N, O, P} = (ZeroTangent(), NonDiffOdd{N, O+1, P}())
135135
(::NonDiffOdd{N, O, O})(Δ) where {N, O} = ntuple(_->ZeroTangent(), N)
136136

137137
# This should not happen
138138
(::NonDiffEven{N, O, O})(Δ...) where {N, O} = error()
139139

140-
@Base.pure function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.apply_type), head, args...)
141-
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}()
140+
@Base.assume_effects :total function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.apply_type), head, args...)
141+
Core.apply_type(head, args...), NonDiffOdd{length(args)+2, 1, 1}()
142142
end
143143

144144
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.tuple), args...)

src/interface.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,9 @@ Base.show(io::IO, f::PrimeDerivativeBack{N}) where {N} = print(io, f.f, "'"^N)
143143

144144
# This improves performance for nested derivatives by short cutting some
145145
# recursion into the PrimeDerivative constructor
146-
@Base.pure minus1(N) = N - 1
147-
@Base.pure plus1(N) = N + 1
148-
lower_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{minus1(N),T}(getfield(f, :f))
146+
lower_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{N-1,T}(getfield(f, :f))
149147
lower_pd(f::PrimeDerivativeBack{1}) = getfield(f, :f)
150-
raise_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{plus1(N),T}(getfield(f, :f))
148+
raise_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{N+1,T}(getfield(f, :f))
151149

152150
ChainRulesCore.rrule(::typeof(lower_pd), f) = lower_pd(f), Δ->(ZeroTangent(), Δ)
153151
ChainRulesCore.rrule(::typeof(raise_pd), f) = raise_pd(f), Δ->(ZeroTangent(), Δ)
@@ -170,8 +168,8 @@ end
170168
PrimeDerivativeFwd(f) = PrimeDerivativeFwd{1, typeof(f)}(f)
171169
PrimeDerivativeFwd(f::PrimeDerivativeFwd{N, T}) where {N, T} = raise_pd(f)
172170

173-
lower_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = (error(); PrimeDerivativeFwd{minus1(N),T}(getfield(f, :f)))
174-
raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{plus1(N),T}(getfield(f, :f))
171+
lower_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = (error(); PrimeDerivativeFwd{N-1,T}(getfield(f, :f)))
172+
raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{N+1,T}(getfield(f, :f))
175173

176174
(f::PrimeDerivativeFwd{0})(x) = getfield(f, :f)(x)
177175

src/stage1/forward.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ my_frule(::ZeroBundle{1, typeof(my_frule)}, args::ATB{1}...) = nothing
2323
(::∂☆{N})(::ZeroBundle{N, typeof(my_frule)}, ::ZeroBundle{N, ZeroBundle{1, typeof(my_frule)}}, args::ATB{N}...) where {N} = ZeroBundle{N}(nothing)
2424

2525
shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} =
26-
UniformBundle{minus1(N), <:Any, U}(UniformBundle{1, B, U}(b.primal, b.tangent.val), b.tangent.val)
26+
UniformBundle{N-1, <:Any, U}(UniformBundle{1, B, U}(b.primal, b.tangent.val), b.tangent.val)
2727

2828
function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
2929
# N.B: This depends on the special properties of the canonical tangent index order
3030
ExplicitTangentBundle{N-1}(
3131
ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)),
32-
ntuple(2^(N-1)-1) do i
32+
ntuple(1<<(N-1)-1) do i
3333
ExplicitTangentBundle{1}(partial(b, 2*i), (partial(b, 2*i+1),))
3434
end)
3535
end
@@ -86,7 +86,7 @@ function shuffle_up(r::CompositeBundle{N}) where {N}
8686
else
8787
return TangentBundle{N+1}(r.tup[1].primal,
8888
(r.tup[1].tangent.partials..., primal(b),
89-
ntuple(i->partial(b,i), 2^(N+1)-1)...))
89+
ntuple(i->partial(b,i), 1<<(N+1)-1)...))
9090
end
9191
end
9292

@@ -124,8 +124,8 @@ function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
124124
end
125125

126126
function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
127-
∂☆p = ∂☆{minus1(N)}()
128-
∂☆p(ZeroBundle{minus1(N)}(my_frule), map(shuffle_down, args)...)
127+
∂☆p = ∂☆{N-1}()
128+
∂☆p(ZeroBundle{N-1}(my_frule), map(shuffle_down, args)...)
129129
end
130130

131131
function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}

src/stage1/generated.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ struct ∂⃖weaveInnerOdd{N, O}; b̄; end
101101
end
102102
@Base.constprop :aggressive function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O}
103103
@destruct c, c̄ = w....)
104-
return (c̄, c), ∂⃖weaveInnerEven{plus1(N), O}()
104+
return (c̄, c), ∂⃖weaveInnerEven{N+1, O}()
105105
end
106106
struct ∂⃖weaveInnerEven{N, O}; end
107107
@Base.constprop :aggressive function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O}
108108
@destruct y, ȳ = Δ′(x...)
109-
return y, ∂⃖weaveInnerOdd{plus1(N), O}(ȳ)
109+
return y, ∂⃖weaveInnerOdd{N+1, O}(ȳ)
110110
end
111111

112112
struct ∂⃖weaveOuterOdd{N, O}; end
@@ -115,15 +115,15 @@ struct ∂⃖weaveOuterOdd{N, O}; end
115115
end
116116
@Base.constprop :aggressive function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O}
117117
@destruct α, ᾱ = Δ′′′(Δ′′)
118-
return (NoTangent(), α...), ∂⃖weaveOuterEven{plus1(N), O}(ᾱ)
118+
return (NoTangent(), α...), ∂⃖weaveOuterEven{N+1, O}(ᾱ)
119119
end
120120
struct ∂⃖weaveOuterEven{N, O}; ᾱ end
121121
@Base.constprop :aggressive function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O}
122-
return w.ᾱ(Base.tail(Δ⁴)...), ∂⃖weaveOuterOdd{plus1(N), O}()
122+
return w.ᾱ(Base.tail(Δ⁴)...), ∂⃖weaveOuterOdd{N+1, O}()
123123
end
124124

125125
function (::∂⃖{N})(::∂⃖{1}, args...) where {N}
126-
@destruct (a, ā) = ∂⃖{plus1(N)}()(args...)
126+
@destruct (a, ā) = ∂⃖{N+1}()(args...)
127127
let O = c_order(N)
128128
(a, Protected{N}(@opaque Δ->begin
129129
(b, b̄) = (Δ)
@@ -188,10 +188,10 @@ end
188188
(::∂⃖rruleD{N, N})(Δ...) where {N} = error("Should not be reached")
189189

190190
# ∂⃖rrule
191-
@Base.pure term_depth(N) = 2^(N-2)
191+
term_depth(N) = 1<<(N-2)
192192
function (::∂⃖rrule{N})(z, z̄) where {N}
193193
@destruct (y, ȳ) = z
194-
y, ∂⃖rruleA{term_depth(N), 1}(∂⃖{minus1(N)}(), ȳ, z̄)
194+
y, ∂⃖rruleA{term_depth(N), 1}(∂⃖{N-1}(), ȳ, z̄)
195195
end
196196

197197
function (::∂⃖{N})(f::Core.IntrinsicFunction, args...) where {N}
@@ -217,7 +217,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
217217
end
218218
return z
219219
else
220-
∂⃖p = ∂⃖{minus1(N)}()
220+
∂⃖p = ∂⃖{N-1}()
221221
@destruct z, z̄ = ∂⃖p(rrule, f, args...)
222222
if z === nothing
223223
return ∂⃖recurse{N}()(f, args...)
@@ -231,7 +231,7 @@ function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) wher
231231
Tuple{Any, Any}(∂⃖{1}()(f, args...))
232232
end
233233

234-
@Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...)
234+
@Base.assume_effects :total function (::∂⃖{1})(::typeof(Core.apply_type), head, args...)
235235
return rrule(Core.apply_type, head, args...)
236236
end
237237

@@ -284,8 +284,8 @@ struct EvenOddEven{O, P, F, G}; f::F; g::G; end
284284
EvenOddEven{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddEven{O, P, F, G}(f, g)
285285
struct EvenOddOdd{O, P, F, G}; f::F; g::G; end
286286
EvenOddOdd{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddOdd{O, P, F, G}(f, g)
287-
@Base.constprop :aggressive (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g))
288-
@Base.constprop :aggressive (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g))
287+
@Base.constprop :aggressive (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{O+1, P, F, G}(o.f, o.g))
288+
@Base.constprop :aggressive (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g...), EvenOddOdd{O+1, P, F, G}(e.f, e.g))
289289
@Base.constprop :aggressive (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ)
290290

291291

@@ -363,11 +363,11 @@ struct ApplyOdd{O, P}; u; ∂⃖f; end
363363
struct ApplyEven{O, P}; u; ∂⃖∂⃖f; end
364364
@Base.constprop :aggressive function (a::ApplyOdd{O, P})(Δ) where {O, P}
365365
r, ∂⃖∂⃖f = a.∂⃖f(Δ)
366-
(a.u(r), ApplyEven{plus1(O), P}(a.u, ∂⃖∂⃖f))
366+
(a.u(r), ApplyEven{O+1, P}(a.u, ∂⃖∂⃖f))
367367
end
368368
@Base.constprop :aggressive function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P}
369369
r, ∂⃖∂⃖∂⃖f = Core._apply_iterate(iterate, a.∂⃖∂⃖f, (ff,), args...)
370-
(r, ApplyOdd{plus1(O), P}(a.u, ∂⃖∂⃖∂⃖f))
370+
(r, ApplyOdd{O+1, P}(a.u, ∂⃖∂⃖∂⃖f))
371371
end
372372
@Base.constprop :aggressive function (a::ApplyOdd{O, O})(Δ) where {O}
373373
r = a.∂⃖f(Δ)
@@ -381,10 +381,10 @@ function (this::∂⃖{N})(::typeof(Core._apply_iterate), iterate, f, args::Unio
381381
end
382382

383383

384-
@Base.pure c_order(N::Int) = 2^N - 1
384+
c_order(N::Int) = 1<<N - 1
385385

386-
@Base.pure function (::∂⃖{N})(::typeof(Core.apply_type), head, args...) where {N}
387-
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, c_order(N)}()
386+
@Base.assume_effects :total function (::∂⃖{N})(::typeof(Core.apply_type), head, args...) where {N}
387+
Core.apply_type(head, args...), NonDiffOdd{length(args)+2, 1, c_order(N)}()
388388
end
389389

390390
@Base.constprop :aggressive lifted_getfield(x, s) = getfield(x, s)

src/stage1/mixed.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ end
1414
function (x::∂⃖composeOdd)(Δ)
1515
b, ∂b = x.b(Δ)
1616
a, ∂a = x.a(b[end])
17-
a, ∂⃖composeEven{N, plus1(N)}(∂a, ∂b)
17+
a, ∂⃖composeEven{N, N+1}(∂a, ∂b)
1818
end
1919

2020
function (x::∂⃖composeEven)(args...)
2121
a, ∂a = x.a(args...)
2222
b, ∂b = x.b(a)
23-
b, ∂⃖composeOdd{N, plus1(N)}(∂a, ∂b)
23+
b, ∂⃖composeOdd{N, N+1}(∂a, ∂b)
2424
end
2525

2626
function (x::∂⃖composeOdd{N,N})(Δ) where {N}

0 commit comments

Comments
 (0)