Skip to content

Commit 0a98045

Browse files
authored
Merge pull request #49 from JuliaDiff/sds/fix_constprop
adapt to JuliaLang/julia#42125
2 parents 2059a35 + 9709486 commit 0a98045

File tree

4 files changed

+46
-320
lines changed

4 files changed

+46
-320
lines changed

Manifest.toml

Lines changed: 0 additions & 274 deletions
This file was deleted.

src/runtime.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
using ChainRulesCore
22

3-
@Base.aggressive_constprop accum(a, b) = a + b
4-
@Base.aggressive_constprop accum(a::Tuple, b::Tuple) = map(accum, a, b)
5-
@Base.aggressive_constprop @generated function accum(x::NamedTuple, y::NamedTuple)
3+
@Base.constprop :aggressive accum(a, b) = a + b
4+
@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)
5+
@Base.constprop :aggressive @generated function accum(x::NamedTuple, y::NamedTuple)
66
fnames = union(fieldnames(x), fieldnames(y))
77
gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent())
88
grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent())
99
Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...)
1010
end
11-
@Base.aggressive_constprop accum(a, b, c, args...) = accum(accum(a, b), c, args...)
12-
@Base.aggressive_constprop accum(a::NoTangent, b) = b
13-
@Base.aggressive_constprop accum(a, b::NoTangent) = a
14-
@Base.aggressive_constprop accum(a::NoTangent, b::NoTangent) = NoTangent()
11+
@Base.constprop :aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...)
12+
@Base.constprop :aggressive accum(a::NoTangent, b) = b
13+
@Base.constprop :aggressive accum(a, b::NoTangent) = a
14+
@Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent()

src/stage1/forward.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,37 +134,37 @@ end
134134
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)
135135

136136
# Special case rules for performance
137-
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
137+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
138138
s = primal(s)
139139
TangentBundle{N}(getfield(primal(x), s),
140140
map(x->lifted_getfield(x, s), x.partials))
141141
end
142142

143-
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
143+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
144144
s = primal(s)
145145
TaylorBundle{N}(getfield(primal(x), s),
146146
map(y->lifted_getfield(y, s), x.coeffs))
147147
end
148148

149-
@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
149+
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
150150
x.tup[primal(s)]
151151
end
152152

153-
@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
153+
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
154154
x.tup[Base.fieldindex(B, primal(s))]
155155
end
156156

157-
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
157+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
158158
s = primal(s)
159159
TangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
160160
map(x->lifted_getfield(x, s), x.partials))
161161
end
162162

163-
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
163+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
164164
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.partial)
165165
end
166166

167-
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U}
167+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U}
168168
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s), primal(inbounds)), x.partial)
169169
end
170170

0 commit comments

Comments
 (0)