Skip to content

Commit 9709486

Browse files
committed
1 parent 167c9eb commit 9709486

File tree

3 files changed

+46
-46
lines changed

3 files changed

+46
-46
lines changed

src/runtime.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
using ChainRulesCore
22

3-
@Base.aggressive_constprop accum(a, b) = a + b
4-
@Base.aggressive_constprop accum(a::Tuple, b::Tuple) = map(accum, a, b)
5-
@Base.aggressive_constprop @generated function accum(x::NamedTuple, y::NamedTuple)
3+
@Base.constprop :aggressive accum(a, b) = a + b
4+
@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)
5+
@Base.constprop :aggressive @generated function accum(x::NamedTuple, y::NamedTuple)
66
fnames = union(fieldnames(x), fieldnames(y))
77
gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent())
88
grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent())
99
Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...)
1010
end
11-
@Base.aggressive_constprop accum(a, b, c, args...) = accum(accum(a, b), c, args...)
12-
@Base.aggressive_constprop accum(a::NoTangent, b) = b
13-
@Base.aggressive_constprop accum(a, b::NoTangent) = a
14-
@Base.aggressive_constprop accum(a::NoTangent, b::NoTangent) = NoTangent()
11+
@Base.constprop :aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...)
12+
@Base.constprop :aggressive accum(a::NoTangent, b) = b
13+
@Base.constprop :aggressive accum(a, b::NoTangent) = a
14+
@Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent()

src/stage1/forward.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,37 +134,37 @@ end
134134
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)
135135

136136
# Special case rules for performance
137-
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
137+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
138138
s = primal(s)
139139
TangentBundle{N}(getfield(primal(x), s),
140140
map(x->lifted_getfield(x, s), x.partials))
141141
end
142142

143-
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
143+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
144144
s = primal(s)
145145
TaylorBundle{N}(getfield(primal(x), s),
146146
map(y->lifted_getfield(y, s), x.coeffs))
147147
end
148148

149-
@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
149+
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
150150
x.tup[primal(s)]
151151
end
152152

153-
@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
153+
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
154154
x.tup[Base.fieldindex(B, primal(s))]
155155
end
156156

157-
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
157+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
158158
s = primal(s)
159159
TangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
160160
map(x->lifted_getfield(x, s), x.partials))
161161
end
162162

163-
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
163+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
164164
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.partial)
165165
end
166166

167-
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U}
167+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U}
168168
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s), primal(inbounds)), x.partial)
169169
end
170170

src/stage1/generated.jl

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ struct Protected{N}
3838
a
3939
end
4040
(p::Protected)(args...) = getfield(p, :a)(args...)[1]
41-
@Base.aggressive_constprop (::∂⃖{N})(p::Protected{N}, args...) where {N} = getfield(p, :a)(args...)
42-
@Base.aggressive_constprop (::∂⃖{1})(p::Protected{1}, args...) = getfield(p, :a)(args...)
41+
@Base.constprop :aggressive (::∂⃖{N})(p::Protected{N}, args...) where {N} = getfield(p, :a)(args...)
42+
@Base.constprop :aggressive (::∂⃖{1})(p::Protected{1}, args...) = getfield(p, :a)(args...)
4343
(::∂⃖{N})(p::Protected, args...) where {N} = error("TODO: Can we support this?")
4444

4545
struct OpticBundle{T}
@@ -94,30 +94,30 @@ end
9494
end
9595

9696
struct ∂⃖weaveInnerOdd{N, O}; b̄; end
97-
@Base.aggressive_constprop function (w::∂⃖weaveInnerOdd{N, N})(Δ) where {N}
97+
@Base.constprop :aggressive function (w::∂⃖weaveInnerOdd{N, N})(Δ) where {N}
9898
@destruct c, c̄ = w....)
9999
return (c̄, c)
100100
end
101-
@Base.aggressive_constprop function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O}
101+
@Base.constprop :aggressive function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O}
102102
@destruct c, c̄ = w....)
103103
return (c̄, c), ∂⃖weaveInnerEven{plus1(N), O}()
104104
end
105105
struct ∂⃖weaveInnerEven{N, O}; end
106-
@Base.aggressive_constprop function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O}
106+
@Base.constprop :aggressive function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O}
107107
@destruct y, ȳ = Δ′(x...)
108108
return y, ∂⃖weaveInnerOdd{plus1(N), O}(ȳ)
109109
end
110110

111111
struct ∂⃖weaveOuterOdd{N, O}; end
112-
@Base.aggressive_constprop function (w::∂⃖weaveOuterOdd{N, N})((Δ′′, Δ′′′)) where {N}
112+
@Base.constprop :aggressive function (w::∂⃖weaveOuterOdd{N, N})((Δ′′, Δ′′′)) where {N}
113113
return (NoTangent(), Δ′′′(Δ′′)...)
114114
end
115-
@Base.aggressive_constprop function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O}
115+
@Base.constprop :aggressive function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O}
116116
@destruct α, ᾱ = Δ′′′(Δ′′)
117117
return (NoTangent(), α...), ∂⃖weaveOuterEven{plus1(N), O}(ᾱ)
118118
end
119119
struct ∂⃖weaveOuterEven{N, O}; ᾱ end
120-
@Base.aggressive_constprop function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O}
120+
@Base.constprop :aggressive function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O}
121121
return w.ᾱ(Base.tail(Δ⁴)...), ∂⃖weaveOuterOdd{plus1(N), O}()
122122
end
123123

@@ -156,33 +156,33 @@ struct ∂⃖rruleB{N, O}; ᾱ; ȳ̄ ; end
156156
struct ∂⃖rruleC{N, O}; ȳ̄ ; Δ′′′; β̄ ; end
157157
struct ∂⃖rruleD{N, O}; γ̄; β̄ ; end
158158

159-
@Base.aggressive_constprop function (a::∂⃖rruleA{N, O})(Δ) where {N, O}
159+
@Base.constprop :aggressive function (a::∂⃖rruleA{N, O})(Δ) where {N, O}
160160
# TODO: Is this unthunk in the right place
161161
@destruct (α, ᾱ) = a.(a.ȳ, unthunk(Δ))
162162
(α, ∂⃖rruleB{N, O}(ᾱ, a.ȳ̄))
163163
end
164164

165-
@Base.aggressive_constprop function (b::∂⃖rruleB{N, O})(Δ′...) where {N, O}
165+
@Base.constprop :aggressive function (b::∂⃖rruleB{N, O})(Δ′...) where {N, O}
166166
@destruct ((Δ′′′, β), β̄) = b.ᾱ(Δ′)
167167
(β, ∂⃖rruleC{N, O}(b.ȳ̄, Δ′′′, β̄))
168168
end
169169

170-
@Base.aggressive_constprop function (c::∂⃖rruleC{N, O})(Δ′′) where {N, O}
170+
@Base.constprop :aggressive function (c::∂⃖rruleC{N, O})(Δ′′) where {N, O}
171171
@destruct (γ, γ̄) = c.ȳ̄((Δ′′, c.Δ′′′))
172172
(Base.tail(γ), ∂⃖rruleD{N, O}(γ̄, c.β̄))
173173
end
174174

175-
@Base.aggressive_constprop function (d::∂⃖rruleD{N, O})(Δ⁴...) where {N, O}
175+
@Base.constprop :aggressive function (d::∂⃖rruleD{N, O})(Δ⁴...) where {N, O}
176176
(δ₁, δ₂), δ̄ = d.γ̄(ZeroTangent(), Δ⁴...)
177177
(δ₁, ∂⃖rruleA{N, O+1}(d.β̄ , δ₂, δ̄ ))
178178
end
179179

180180
# Terminal cases
181-
@Base.aggressive_constprop function (c::∂⃖rruleB{N, N})(Δ′...) where {N}
181+
@Base.constprop :aggressive function (c::∂⃖rruleB{N, N})(Δ′...) where {N}
182182
@destruct (Δ′′′, β) = c.ᾱ(Δ′)
183183
(β, ∂⃖rruleC{N, N}(c.ȳ̄, Δ′′′, nothing))
184184
end
185-
@Base.aggressive_constprop (c::∂⃖rruleC{N, N})(Δ′′) where {N} =
185+
@Base.constprop :aggressive (c::∂⃖rruleC{N, N})(Δ′′) where {N} =
186186
Base.tail(c.ȳ̄((Δ′′, c.Δ′′′)))
187187
(::∂⃖rruleD{N, N})(Δ...) where {N} = error("Should not be reached")
188188

@@ -255,17 +255,17 @@ function ChainRulesCore.rrule(::KwFunc, kwargs, f, args...)
255255
end
256256
end
257257

258-
@Base.aggressive_constprop function ChainRulesCore.rrule(::typeof(Core.getfield), s, field::Symbol)
258+
@Base.constprop :aggressive function ChainRulesCore.rrule(::typeof(Core.getfield), s, field::Symbol)
259259
getfield(s, field), let P = typeof(s)
260-
@Base.aggressive_constprop Δ->begin
260+
@Base.constprop :aggressive Δ->begin
261261
nt = NamedTuple{(field,)}((Δ,))
262262
(NoTangent(), Tangent{P, typeof(nt)}(nt), NoTangent())
263263
end
264264
end
265265
end
266266

267267
struct ∂⃖getfield{n, f}; end
268-
@Base.aggressive_constprop function (::∂⃖getfield{n, f})(Δ) where {n,f}
268+
@Base.constprop :aggressive function (::∂⃖getfield{n, f})(Δ) where {n,f}
269269
if @generated
270270
return Expr(:call, tuple, NoTangent(),
271271
Expr(:call, tuple, (i == f ? :(Δ) : ZeroTangent() for i = 1:n)...),
@@ -279,31 +279,31 @@ struct EvenOddEven{O, P, F, G}; f::F; g::G; end
279279
EvenOddEven{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddEven{O, P, F, G}(f, g)
280280
struct EvenOddOdd{O, P, F, G}; f::F; g::G; end
281281
EvenOddOdd{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddOdd{O, P, F, G}(f, g)
282-
@Base.aggressive_constprop (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g))
283-
@Base.aggressive_constprop (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g))
284-
@Base.aggressive_constprop (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ)
282+
@Base.constprop :aggressive (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g))
283+
@Base.constprop :aggressive (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g))
284+
@Base.constprop :aggressive (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ)
285285

286286

287-
@Base.aggressive_constprop function (::∂⃖{N})(::typeof(Core.getfield), s, field::Int) where {N}
287+
@Base.constprop :aggressive function (::∂⃖{N})(::typeof(Core.getfield), s, field::Int) where {N}
288288
getfield(s, field), EvenOddOdd{1, c_order(N)}(
289289
∂⃖getfield{nfields(s), field}(),
290-
@Base.aggressive_constprop (_, Δ, _)->getfield(Δ, field))
290+
@Base.constprop :aggressive (_, Δ, _)->getfield(Δ, field))
291291
end
292292

293-
@Base.aggressive_constprop function (::∂⃖{N})(::typeof(Base.getindex), s::Tuple, field::Int) where {N}
293+
@Base.constprop :aggressive function (::∂⃖{N})(::typeof(Base.getindex), s::Tuple, field::Int) where {N}
294294
getfield(s, field), EvenOddOdd{1, c_order(N)}(
295295
∂⃖getfield{nfields(s), field}(),
296-
@Base.aggressive_constprop (_, Δ, _)->lifted_getfield(Δ, field))
296+
@Base.constprop :aggressive (_, Δ, _)->lifted_getfield(Δ, field))
297297
end
298298

299299
function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N}
300300
getfield(s, field), let P = typeof(s)
301301
EvenOddOdd{1, c_order(N)}(
302-
(@Base.aggressive_constprop Δ->begin
302+
(@Base.constprop :aggressive Δ->begin
303303
nt = NamedTuple{(field,)}((Δ,))
304304
(NoTangent(), Tangent{P, typeof(nt)}(nt), NoTangent())
305305
end),
306-
(@Base.aggressive_constprop (_, Δs, _)->begin
306+
(@Base.constprop :aggressive (_, Δs, _)->begin
307307
isa(Δs, Union{ZeroTangent, NoTangent}) ? Δs : getfield(ChainRulesCore.backing(Δs), field)
308308
end))
309309
end
@@ -313,13 +313,13 @@ end
313313
function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N}
314314
getindex(a, inds...), let
315315
EvenOddOdd{1, c_order(N)}(
316-
(@Base.aggressive_constprop Δ->begin
316+
(@Base.constprop :aggressive Δ->begin
317317
Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...)
318318
BB = zero(a)
319319
BB[inds...] = Δ
320320
(NoTangent(), BB, map(x->NoTangent(), inds)...)
321321
end),
322-
(@Base.aggressive_constprop (_, Δ, _)->begin
322+
(@Base.constprop :aggressive (_, Δ, _)->begin
323323
getindex(Δ, inds...)
324324
end))
325325
end
@@ -355,15 +355,15 @@ end
355355

356356
struct ApplyOdd{O, P}; u; ∂⃖f; end
357357
struct ApplyEven{O, P}; u; ∂⃖∂⃖f; end
358-
@Base.aggressive_constprop function (a::ApplyOdd{O, P})(Δ) where {O, P}
358+
@Base.constprop :aggressive function (a::ApplyOdd{O, P})(Δ) where {O, P}
359359
r, ∂⃖∂⃖f = a.∂⃖f(Δ)
360360
(a.u(r), ApplyEven{plus1(O), P}(a.u, ∂⃖∂⃖f))
361361
end
362-
@Base.aggressive_constprop function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P}
362+
@Base.constprop :aggressive function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P}
363363
r, ∂⃖∂⃖∂⃖f = Core._apply_iterate(iterate, a.∂⃖∂⃖f, (ff,), args...)
364364
(r, ApplyOdd{plus1(O), P}(a.u, ∂⃖∂⃖∂⃖f))
365365
end
366-
@Base.aggressive_constprop function (a::ApplyOdd{O, O})(Δ) where {O}
366+
@Base.constprop :aggressive function (a::ApplyOdd{O, O})(Δ) where {O}
367367
r = a.∂⃖f(Δ)
368368
a.u(r)
369369
end
@@ -381,7 +381,7 @@ end
381381
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, c_order(N)}()
382382
end
383383

384-
@Base.aggressive_constprop lifted_getfield(x, s) = getfield(x, s)
384+
@Base.constprop :aggressive lifted_getfield(x, s) = getfield(x, s)
385385
lifted_getfield(x::ZeroTangent, s) = ZeroTangent()
386386
lifted_getfield(x::NoTangent, s) = NoTangent()
387387

0 commit comments

Comments
 (0)