Skip to content

Commit 867acc5

Browse files
committed
Merge branch 'ox/type_stable_zero' into ox/mutation2
2 parents b507348 + 301928d commit 867acc5

File tree

10 files changed

+49
-39
lines changed

10 files changed

+49
-39
lines changed

src/codegen/forward.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function fwd_transform!(ci, mi, nargs, N)
3434
args = map(stmt.args) do stmt
3535
emit!(mapstmt!(stmt))
3636
end
37-
return Expr(:call, Core._apply_iterate, FwdIterate(ZeroBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...)
37+
return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...)
3838
elseif isa(stmt, SSAValue)
3939
return SSAValue(ssa_mapping[stmt.id])
4040
elseif isa(stmt, Core.SlotNumber)
@@ -62,7 +62,7 @@ function fwd_transform!(ci, mi, nargs, N)
6262
# Always disable `@inbounds`, as we don't actually know if the AD'd
6363
# code is truly `@inbounds` or not.
6464
elseif isexpr(stmt, :boundscheck)
65-
return ZeroBundle{N}(true)
65+
return DNEBundle{N}(true)
6666
else
6767
# Fallback case, for literals.
6868
# If it is an Expr, then it is not a literal

src/codegen/forward_demand.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,6 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met
363363
rt = CC._ir_abstract_constant_propagation(interp, irsv)
364364

365365
ir = compact!(ir)
366-
366+
367367
return ir
368368
end

src/higher_fwd_rules.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,27 @@ end
1919

2020
jeval(j, x) = j(x)
2121
for f in (sin, cos, exp)
22-
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(f)}, x::TaylorBundle{N}) where {N}
22+
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(f)}, x::TaylorBundle{N}) where {N}
2323
njet(Val{N}(), primal(fb), primal(x))(x)
2424
end
25-
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, fb::ZeroBundle{M, typeof(f)}, x::TaylorBundle{M}) where {N, M}
25+
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, fb::AbstractZeroBundle{M, typeof(f)}, x::TaylorBundle{M}) where {N, M}
2626
∂⃖ₙ(jeval, njet(Val{N+M}(), primal(fb), primal(x)), x)
2727
end
2828
end
2929

3030
# TODO: It's a bit embarassing that we need to write these out, but currently the
3131
# compiler is not strong enough to automatically lift the frule. Let's hope we
3232
# can delete these in the near future.
33-
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
33+
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
3434
TaylorBundle{N}(primal(a) + primal(b),
3535
map(+, a.tangent.coeffs, b.tangent.coeffs))
3636
end
3737

38-
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::ZeroBundle{N}) where {N}
38+
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::AbstractZeroBundle{N}) where {N}
3939
TaylorBundle{N}(primal(a) + primal(b), a.tangent.coeffs)
4040
end
4141

42-
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
42+
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
4343
TaylorBundle{N}(primal(a) - primal(b),
4444
map(-, a.tangent.coeffs, b.tangent.coeffs))
4545
end

src/stage1/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99

1010
n_getfield(∂ₙ::∂☆{N}, b::ATB{N}, x::Union{Symbol, Int}) where {N} = ∂ₙ(ZeroBundle{N}(getfield), b, ZeroBundle{N}(x))
1111

12-
function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
12+
function (∂ₙ::∂☆{N})(zc::AbstractZeroBundle{N, typeof(copy)},
1313
bc::ATB{N, <:Broadcasted}) where {N}
1414
bc = ∂ₙ(ZeroBundle{N}(Broadcast.flatten), bc)
1515
args = n_getfield(∂ₙ, bc, :args)

src/stage1/forward.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,12 @@ struct ∂☆shuffle{N}; end
9898

9999
function shuffle_base(r)
100100
(primal, dual) = r
101-
if isa(dual, Union{NoTangent, ZeroTangent})
101+
if dual isa NoTangent
102102
UniformBundle{1}(primal, dual)
103103
else
104+
if dual isa ZeroTangent # Normalize zero for type-stability reasons
105+
dual = zero_tangent(primal)
106+
end
104107
TaylorBundle{1}(primal, (dual,))
105108
end
106109
end
@@ -193,25 +196,25 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}}
193196
end
194197
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)
195198

196-
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N}
199+
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N}
197200
∂vararg{N}()(map(FwdMap(f), destructure(tup))...)
198201
end
199202

200-
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
203+
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
201204
# TODO: This could do an inplace map! to avoid the extra rebundling
202205
rebundle(map(FwdMap(f), map(unbundle, args)...))
203206
end
204207

205-
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N}
208+
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N}
206209
∂☆recurse{N}()(ZeroBundle{N, typeof(map)}(map), f, args...)
207210
end
208211

209212

210-
function (::∂☆{N})(f::ZeroBundle{N, typeof(ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}
213+
function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}
211214
ifelse(arg.primal, args...)
212215
end
213216

214-
function (::∂☆{N})(f::ZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}
217+
function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}
215218
Core.ifelse(arg.primal, args...)
216219
end
217220

@@ -233,48 +236,48 @@ end
233236
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
234237
end
235238

236-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
239+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
237240
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
238241
end
239242

240243

241-
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N}
244+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N}
242245
r = iterate(destructure(t))
243246
r === nothing && return ZeroBundle{N}(nothing)
244247
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
245248
end
246249

247-
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
250+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
248251
r = iterate(destructure(t), primal(a), map(primal, args)...)
249252
r === nothing && return ZeroBundle{N}(nothing)
250253
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
251254
end
252255

253-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N}
256+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N}
254257
r = Base.indexed_iterate(destructure(t), primal(i))
255258
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
256259
end
257260

258-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
261+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
259262
r = Base.indexed_iterate(destructure(t), primal(i), primal(st1), map(primal, st)...)
260263
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
261264
end
262265

263-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N}
266+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N}
264267
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1))
265268
end
266269

267-
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N}
270+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::AbstractZeroBundle) where {N}
268271
field_ind = primal(i)
269272
the_partials = ntuple(order_ind->partial(t, order_ind)[field_ind], N)
270273
TaylorBundle{N}(primal(t)[field_ind], the_partials)
271274
end
272275

273-
function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}
276+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}
274277
DNEBundle{N}(typeof(primal(x)))
275278
end
276279

277-
function (this::∂☆{N})(f::ZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N}...) where {N}
280+
function (this::∂☆{N})(f::AbstractZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N}...) where {N}
278281
ff = primal(f)
279282
if ff === Base.not_int
280283
DNEBundle{N}(ff(map(primal, args)...))

src/stage1/mixed.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,26 +70,26 @@ function (f::FwdIterate)(arg::ATB{N}, st) where {N}
7070
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
7171
end
7272
73-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
73+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
7474
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
7575
end
7676
=#
7777

78-
function (this::∂⃖{N})(that::∂☆{M}, ::ZeroBundle{M, typeof(Core._apply_iterate)},
78+
function (this::∂⃖{N})(that::∂☆{M}, ::AbstractZeroBundle{M, typeof(Core._apply_iterate)},
7979
iterate, f, args::ATB{M, <:Tuple}...) where {N, M}
8080
@assert primal(iterate) === Base.iterate
8181
x, ∂⃖f = Core._apply_iterate(FwdIterate(iterate), this, (that, f), args...)
8282
return x, ApplyOdd{1, c_order(N)}(UnApply{map(x->length(primal(x)), args)}(), ∂⃖f)
8383
end
8484

8585

86-
function ChainRules.rrule(∂::∂☆{N}, m::ZeroBundle{N, typeof(map)}, p::ZeroBundle{N, typeof(+)}, A::ATB{N}, B::ATB{N}) where {N}
86+
function ChainRules.rrule(∂::∂☆{N}, m::AbstractZeroBundle{N, typeof(map)}, p::AbstractZeroBundle{N, typeof(+)}, A::ATB{N}, B::ATB{N}) where {N}
8787
(m, p, A, B), Δ->(NoTangent(), NoTangent(), NoTangent(), Δ, Δ)
8888
end
8989

9090
mapev_unbundled(_, js, a) = rebundle(mapev(js, unbundle(a)))
91-
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::ZeroBundle{M, typeof(map)},
92-
f::ZeroBundle{M}, a::ATB{M, <:Array}) where {N, M}
91+
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::AbstractZeroBundle{M, typeof(map)},
92+
f::AbstractZeroBundle{M}, a::ATB{M, <:Array}) where {N, M}
9393
@assert Base.issingletontype(typeof(primal(f)))
9494
js = map(primal(a)) do x
9595
∂f = ∂☆{N+M}()(ZeroBundle{N+M}(primal(f)),

src/stage1/recurse_fwd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ struct ∂☆new{N}; end
1515
function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...)
1616
primal_args = map(primal, xs)
1717
the_primal = _construct(B, primal_args)
18-
@info "∂☆new{1}"
1918
tangent_tup = map(first_partial, xs)
2019
the_partial = if B<:Tuple
2120
Tangent{B, typeof(tangent_tup)}(tangent_tup)
@@ -24,13 +23,14 @@ function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...)
2423
tangent_nt = NamedTuple{names}(tangent_tup)
2524
StructuralTangent{B}(tangent_nt)
2625
end
26+
@show typeof(the_partial)
27+
#TODO: I think we need https://github.com/JuliaDiff/Diffractor.jl/pull/236/files here
2728
return TaylorBundle{1, B}(the_primal, (the_partial,))
2829
end
2930

3031
function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N}
3132
primal_args = map(primal, xs)
3233
the_primal = _construct(B, primal_args)
33-
@info "∂☆new{N}"
3434
the_partials = ntuple(Val{N}()) do ii
3535
tangent_tup = map(x->partial(x, ii), xs)
3636
tangent = if B<:Tuple

src/tangent.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,9 +427,11 @@ Creates a bundle with a zero tangent.
427427
"""
428428
struct zero_bundle{N} end
429429
function (::zero_bundle{N})(primal) where N
430-
if zero_tangent(primal) isa ZeroTangent
431-
return ZeroBundle{N}(primal)
430+
if zero_tangent(primal) isa AbstractZero
431+
return UniformBundle{N}(primal, zero_tangent(primal) )
432432
else
433+
# Note: it is important that zero_tangent(primal) is called in ntuple
434+
# so it gets distrinct values for each order, so it doesn't alias if mutated.
433435
return TaylorBundle{N}(primal, ntuple(_->zero_tangent(primal), N))
434436
end
435437
end

test/stage2_fwd.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,37 @@ module stage2_fwd
66
@test sin′(1.0) == cos(1.0)
77
end
88
let sin′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(mysin), Float64}, 2)
9-
@test isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64})
9+
# This broke some time between 1.10 and 1.11-DEV.10001
10+
@test_broken isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64})
1011
@test sin′′(1.0) == -sin(1.0)
1112
end
1213

1314
myminus(a, b) = a - b
1415
self_minus(a) = myminus(a, a)
1516
ChainRulesCore.@scalar_rule myminus(x, y) (true, -1)
1617
let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64})
17-
@test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64})
18+
# This broke some time between 1.10 and 1.11-DEV.10001
19+
@test_broken isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64})
1820
@test self_minus′(1.0) == 0.
1921
end
2022
ChainRulesCore.@scalar_rule myminus(x, y) (true, true) # frule for `x - y`
2123
let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64})
22-
@test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64})
24+
# This broke some time between 1.10 and 1.11-DEV.10001
25+
@test_broken isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64})
2326
@test self_minus′(1.0) == 2.
2427
end
2528

2629
myminus2(a, b) = a - b
2730
self_minus2(a) = myminus2(a, a)
2831
let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64})
29-
@test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64})
32+
# This broke some time between 1.10 and 1.11-DEV.10001
33+
@test_broken isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64})
3034
@test self_minus2′(1.0) == 0.
3135
end
3236
ChainRulesCore.@scalar_rule myminus2(x, y) (true, true) # frule for `x - y`
3337
let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64})
34-
@test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64})
38+
# This broke some time between 1.10 and 1.11-DEV.10001
39+
@test_broken isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64})
3540
@test self_minus2′(1.0) == 2.
3641
end
3742

test/tangent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ end
126126
zero_bundle = Diffractor.zero_bundle
127127

128128
tup_zb = zero_bundle{2}()((1, 0))
129-
@test tup_zb isa ZeroBundle{2}
129+
@test tup_zb isa Diffractor.AbstractTangentBundle{2}
130130
@test iszero(tup_zb[TaylorTangentIndex(1)])
131131
@test iszero(tup_zb[TaylorTangentIndex(2)])
132132

0 commit comments

Comments
 (0)