Skip to content

Commit e9cc221

Browse files
committed
overhaul zero_tangent and MutableTangent for type stability
1 parent 5574691 commit e9cc221

File tree

3 files changed

+177
-99
lines changed

3 files changed

+177
-99
lines changed

src/tangent_types/abstract_zero.jl

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -111,22 +111,40 @@ function zero_tangent end
111111
zero_tangent(x::Number) = zero(x)
112112

113113
@generated function zero_tangent(primal)
114-
has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples
115114
zfield_exprs = map(fieldnames(primal)) do fname
116-
fval = if isdefined(primal, fname)
117-
Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname)))
118-
else
119-
ZeroTangent()
120-
end
115+
fval = :(
116+
if isdefined(primal, $(QuoteNode(fname)))
117+
zero_tangent(getfield(primal, $(QuoteNode(fname))))
118+
else
119+
# This is going to be potentially bad, but that's what they get for not giving us a primal
120+
# This will never be mutated in place; rather, it will always be replaced with an actual value first
121+
ZeroTangent()
122+
end
123+
)
121124
Expr(:kw, fname, fval)
122125
end
123-
backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...))
124-
return :($MutableTangent{$primal}($backing_expr))
126+
127+
return if has_mutable_tangent(primal)
128+
any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
129+
# If it is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
130+
fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype))
131+
Expr(:kw, fname, fdef)
132+
end
133+
:($MutableTangent{$primal}(
134+
$(Expr(:tuple, Expr(:parameters, any_mask...))),
135+
$(Expr(:tuple, Expr(:parameters, zfield_exprs...)))
136+
))
137+
else
138+
:($Tangent{$primal}($(Expr(:parameters, zfield_exprs...))))
139+
end
125140
end
126141

142+
zero_tangent(primal::Tuple) = Tangent{typeof(primal)}(map(zero_tangent, primal)...)
143+
127144
function zero_tangent(x::Array{P,N}) where {P,N}
128-
(isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) &&
145+
if (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x)))
129146
return map(zero_tangent, x)
147+
end
130148

131149
# Now we need to handle nonfully assigned arrays
132150
# see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265
@@ -139,9 +157,8 @@ function zero_tangent(x::Array{P,N}) where {P,N}
139157
return y
140158
end
141159

160+
# Sad heuristic methods we need because of unassigned values
142161
guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
143-
function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N}
144-
return Array{guess_zero_tangent_type(T),N}
145-
end
146-
guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634
147-
# TODO: we might be able to do better than this. even without.
162+
guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T)))
163+
guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = return Array{guess_zero_tangent_type(T),N}
164+
guess_zero_tangent_type(T::Type)= Any

src/tangent_types/structural_tangent.jl

Lines changed: 99 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,90 @@ as an object with mirroring fields.
1313
"""
1414
abstract type StructuralTangent{P} <: AbstractTangent end
1515

16+
"""
17+
Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent
18+
19+
This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`.
20+
`P` is the corresponding primal type that this is a tangent for.
21+
22+
`Tangent{P}` should have fields (technically properties), that match to a subset of the
23+
fields of the primal type; and each should be a tangent type matching to the primal
24+
type of that field.
25+
Fields of `P` that are not present in the Tangent are treated as `Zero`.
26+
27+
`T` is an implementation detail representing the backing data structure.
28+
For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
29+
It should not be passed in by user.
30+
31+
For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly
32+
to for a tuple.
33+
For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values
34+
via `tangent.fieldname`.
35+
Any fields not explicitly present in the `Tangent` are treated as being set to `ZeroTangent()`.
36+
To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref)
37+
function is provided.
38+
"""
39+
struct Tangent{P,T} <: StructuralTangent{P}
40+
# Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
41+
# (but potentially a different one, as it doesn't contain tangents)
42+
backing::T
43+
44+
function Tangent{P,T}(backing) where {P,T}
45+
if P <: Tuple
46+
T <: Tuple || _backing_error(P, T, Tuple)
47+
elseif P <: AbstractDict
48+
T <: AbstractDict || _backing_error(P, T, AbstractDict)
49+
elseif P === Any # can be anything
50+
else # Any other struct (including NamedTuple)
51+
T <: NamedTuple || _backing_error(P, T, NamedTuple)
52+
end
53+
return new(backing)
54+
end
55+
end
56+
57+
"""
58+
MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent
59+
60+
This type represents the tangent to a mutable struct.
61+
It itself is also mutable.
62+
63+
!!! warning "Experimental"
64+
MutableTangent is an experimental feature, and is part of the mutation support featureset.
65+
While this notice remains, its behavior and interface may change in any _minor_ version of ChainRulesCore.
66+
Exactly how it should be used (e.g. is it forward-mode only?) is not yet settled.
67+
68+
!!! warning "Do not directly mess with the tangent backing data"
69+
It is relatively straightforward for a forward-mode AD to work correctly in the presence of mutation and aliasing of primal values.
70+
However, this requires that the tangent is aliased in turn, and conversely that it is copied when the primal is.
71+
If you separately alias the backing data, e.g. by using the internal `ChainRulesCore.backing` function, you can break this.
72+
"""
73+
struct MutableTangent{P,F} <: StructuralTangent{P}
74+
backing::F
75+
76+
function MutableTangent{P}(fieldvals) where P
77+
backing = map(Ref, fieldvals)
78+
return new{P, typeof(backing)}(backing)
79+
end
80+
function MutableTangent{P}(
81+
any_mask::NamedTuple{names, <:NTuple{<:Any, Bool}}, fvals::NamedTuple{names}
82+
) where {names, P}
83+
84+
backing = map(any_mask, fvals) do isany, fval
85+
ref = if isany
86+
Ref{Any}
87+
else
88+
Ref
89+
end
90+
return ref(fval)
91+
end
92+
return new{P, typeof(backing)}(backing)
93+
end
94+
end
95+
96+
####################################################################
97+
# StructuralTangent Common
98+
99+
16100
function StructuralTangent{P}(nt::NamedTuple) where {P}
17101
if has_mutable_tangent(P)
18102
return MutableTangent{P}(nt)
@@ -21,6 +105,7 @@ function StructuralTangent{P}(nt::NamedTuple) where {P}
21105
end
22106
end
23107

108+
24109
has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(P) > 0)
25110

26111

@@ -40,6 +125,9 @@ end
40125
Base.iszero(t::StructuralTangent) = all(iszero, backing(t))
41126

42127
function Base.map(f, tangent::StructuralTangent{P}) where {P}
128+
#TODO: is it even useful to support this on MutableTangents?
129+
#TODO: we implicitly assume only linear `f` are called and that it is safe to ignore noncanonical Zeros
130+
# This feels like a fair assumption since all normal operations on tangents are linear
43131
L = propertynames(backing(tangent))
44132
vals = map(f, Tuple(backing(tangent)))
45133
named_vals = NamedTuple{L,typeof(vals)}(vals)
@@ -63,7 +151,8 @@ primal types.
63151
backing(x::Tuple) = x
64152
backing(x::NamedTuple) = x
65153
backing(x::Dict) = x
66-
backing(x::StructuralTangent) = getfield(x, :backing)
154+
backing(x::Tangent) = getfield(x, :backing)
155+
backing(x::MutableTangent) = map(getindex, getfield(x, :backing))
67156

68157
# For generic structs
69158
function backing(x::T)::NamedTuple where {T}
@@ -206,46 +295,8 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P}
206295
return println(io)
207296
end
208297

209-
"""
210-
Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent
211-
212-
This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`.
213-
`P` is the the corresponding primal type that this is a tangent for.
214-
215-
`Tangent{P}` should have fields (technically properties), that match to a subset of the
216-
fields of the primal type; and each should be a tangent type matching to the primal
217-
type of that field.
218-
Fields of the P that are not present in the Tangent are treated as `Zero`.
219-
220-
`T` is an implementation detail representing the backing data structure.
221-
For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
222-
It should not be passed in by user.
223-
224-
For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly
225-
to for a tuple.
226-
For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values
227-
via `tangent.fieldname`.
228-
Any fields not explictly present in the `Tangent` are treated as being set to `ZeroTangent()`.
229-
To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref)
230-
function is provided.
231-
"""
232-
struct Tangent{P,T} <: StructuralTangent{P}
233-
# Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
234-
# (but potentially a different one, as it doesn't contain tangents)
235-
backing::T
236-
237-
function Tangent{P,T}(backing) where {P,T}
238-
if P <: Tuple
239-
T <: Tuple || _backing_error(P, T, Tuple)
240-
elseif P <: AbstractDict
241-
T <: AbstractDict || _backing_error(P, T, AbstractDict)
242-
elseif P === Any # can be anything
243-
else # Any other struct (including NamedTuple)
244-
T <: NamedTuple || _backing_error(P, T, NamedTuple)
245-
end
246-
return new(backing)
247-
end
248-
end
298+
#######################################
299+
# immutable Tangent
249300

250301
function Tangent{P}(; kwargs...) where {P}
251302
backing = (; kwargs...) # construct as NamedTuple
@@ -401,46 +452,19 @@ canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent
401452
canonicalize(tangent::Tangent{Any,<:Tuple}) = tangent
402453
canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent
403454

404-
405-
"""
406-
MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent
407-
408-
This type represents the tangent to a mutable struct.
409-
It itself is also mutable.
410-
411-
!!! warning Exprimental
412-
MutableTangent is an experimental feature, and is part of the mutation support featureset.
413-
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
414-
Exactly how it should be used (e.g. is it forward-mode only?)
415-
416-
!!! warning Do not directly mess with the tangent backing data
417-
It is relatively straight forward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values.
418-
However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is).
419-
If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this.
420-
"""
421-
mutable struct MutableTangent{P} <: StructuralTangent{P}
422-
#TODO: we may want to absolutely lock the type of this down
423-
backing::NamedTuple
424-
end
455+
###################################################
456+
# MutableTangent
425457

426458
MutableTangent{P}(;kwargs...) where P = MutableTangent{P}(NamedTuple(kwargs))
427459

428-
Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(backing(tangent), idx)
429-
Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(backing(tangent), idx) # break ambig
460+
ref_backing(t::MutableTangent) = getfield(t, :backing)
430461

431-
function Base.setproperty!(tangent::MutableTangent, name::Symbol, x)
432-
new_backing = Base.setindex(backing(tangent), x, name)
433-
setfield!(tangent, :backing, new_backing)
434-
return x
435-
end
462+
Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(ref_backing(tangent), idx)[]
463+
Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(ref_backing(tangent), idx)[] # break ambig
436464

437-
function Base.setproperty!(tangent::MutableTangent, idx::Int, x)
438-
# needed due to https://github.com/JuliaLang/julia/issues/43155
439-
name = idx2sym(backing(tangent), idx)
440-
return setproperty!(tangent, name, x)
441-
end
465+
Base.setproperty!(tangent::MutableTangent, name::Symbol, x) = getproperty(ref_backing(tangent), name)[] = x
466+
Base.setproperty!(tangent::MutableTangent, idx::Int, x) = getproperty(ref_backing(tangent), idx)[] = x # break ambig
442467

443-
idx2sym(::NamedTuple{names}, idx) where names = names[idx]
444468

445469
Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h)
446470
function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2}

test/tangent_types/abstract_zero.jl

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@
162162
end
163163

164164
@testset "zero_tangent" begin
165+
@test zero_tangent(1) === 0
166+
@test zero_tangent(1.0) === 0.0
165167
mutable struct MutDemo
166168
x::Float64
167169
end
@@ -171,34 +173,34 @@ end
171173
@test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo}
172174
@test iszero(zero_tangent(MutDemo(1.5)))
173175

174-
@test zero_tangent((; a=1)) isa ZeroTangent
175-
@test zero_tangent(Demo(1.2)) isa ZeroTangent
176-
177-
@test zero_tangent(1) === 0
178-
@test zero_tangent(1.0) === 0.0
176+
@test zero_tangent((; a=1)) isa Tangent{typeof((;a=1))}
177+
@test zero_tangent(Demo(1.2)) isa Tangent{Demo}
178+
@test zero_tangent(Demo(1.2)).x === 0.0
179179

180180
@test zero_tangent([1.0, 2.0]) == [0.0, 0.0]
181181
@test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]]
182182

183+
@test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0)
184+
183185
@testset "undef elements Vector" begin
184186
x = Vector{Vector{Float64}}(undef, 3)
185187
x[2] = [1.0, 2.0]
186188
dx = zero_tangent(x)
187189
@test dx isa Vector{Vector{Float64}}
188190
@test length(dx) == 3
189-
@test !isassigned(dx, 1)
191+
@test !isassigned(dx, 1) # We may reconsider this later
190192
@test dx[2] == [0.0, 0.0]
191-
@test !isassigned(dx, 3)
193+
@test !isassigned(dx, 3) # We may reconsider this later
192194

193195
a = Vector{MutDemo}(undef, 3)
194196
a[2] = MutDemo(1.5)
195197
da = zero_tangent(a)
196-
@test !isassigned(da, 1)
198+
@test !isassigned(da, 1) # We may reconsider this later
197199
@test iszero(da[2])
198-
@test !isassigned(da, 3)
200+
@test !isassigned(da, 3) # We may reconsider this later
199201

200202
db = zero_tangent(Vector{MutDemo}(undef, 3))
201-
@test all(ii -> !isassigned(db, ii), eachindex(db))
203+
@test all(ii -> !isassigned(db, ii), eachindex(db)) # We may reconsider this later
202204
@test length(db) == 3
203205
@test db isa Vector
204206
end
@@ -217,5 +219,40 @@ end
217219
@test iszero(dy.intro)
218220
@test iszero(dy.contents)
219221
@test (dy.contents = 2.0) == 2.0 # should be assignable
222+
223+
mutable struct MyPartiallyDefinedStructWithAnys
224+
intro::Float64
225+
contents::Any
226+
MyPartiallyDefinedStructWithAnys(x) = new(x)
227+
end
228+
dy = zero_tangent(MyPartiallyDefinedStructWithAnys(1.5))
229+
@test iszero(dy.intro)
230+
@test iszero(dy.contents)
231+
@test dy.contents === ZeroTangent() # we just don't know anything about this data
232+
@test (dy.contents = 2.0) == 2.0 # should be assignable
233+
@test (dy.contents = [2.0, 4.0]) == [2.0, 4.0] # should be assignable to different values
234+
235+
mutable struct MyStructWithNonConcreteFields
236+
x::Any
237+
y::Union{Float64, Vector{Float64}}
238+
z::AbstractVector
239+
end
240+
d = zero_tangent(MyStructWithNonConcreteFields(1.0, 2.0, [3.0]))
241+
@test iszero(d.x)
242+
d.x = Tangent{Base.RefValue{Float64}}(x=1.5)
243+
@test d.x == Tangent{Base.RefValue{Float64}}(x=1.5) #should be assignable
244+
d.x=2.4
245+
@test d.x == 2.4 #should be assignable
246+
@test iszero(d.y)
247+
d.y=2.4
248+
@test d.y == 2.4 #should be assignable
249+
d.y=[2.4]
250+
@test d.y == [2.4] #should be assignable
251+
@test iszero(d.z)
252+
d.z = [1.0, 2.0]
253+
@test d.z == [1.0, 2.0]
254+
d.z = @view [2.0,3.0,4.0][1:2]
255+
@test d.z == [2.0, 3.0]
256+
@test d.z isa SubArray
220257
end
221258
end

0 commit comments

Comments
 (0)