Skip to content

Commit 18eea00

Browse files
committed
use unthunk rather than extern for thunks
more composite stuff Work on constructing wip break things up make elementwise add Fixup tests mostly fix comment how about we don't just stackoverflow when we do getproperty add good tests add tests (pun intended) more tests fix scaling tests add error Update src/composite_core.jl Co-Authored-By: Nick Robinson <npr251@gmail.com> Update src/differentials.jl Co-Authored-By: Nick Robinson <npr251@gmail.com> Update src/differential_arithmetic.jl Co-Authored-By: Mathieu Besançon <mathieu.besancon@gmail.com> Update src/differential_arithmetic.jl resolve comments from PR add tests on indexing, iterating, and properties fix eltype test extern check on allocations and backing behavour test about internals delete ambi tests because we intentionally have some now use P rather than Primal for typevar Update src/composite_core.jl sort importand and exports remove extra blank lines move comment Fix name tuple construction
1 parent 166a65c commit 18eea00

File tree

6 files changed

+359
-108
lines changed

6 files changed

+359
-108
lines changed

src/ChainRulesCore.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
module ChainRulesCore
2-
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
2+
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
33

44
export frule, rrule
5-
export wirtinger_conjugate, wirtinger_primal, refine_differential
5+
export refine_differential, wirtinger_conjugate, wirtinger_primal
66
export @scalar_rule, @thunk
77
export extern, store!, unthunk
8-
export Wirtinger, Zero, One, DoesNotExist, Thunk, InplaceableThunk
8+
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Wirtinger, Zero
99
export NO_FIELDS
1010

1111
include("differentials.jl")
12+
include("composite_core.jl")
1213
include("differential_arithmetic.jl")
14+
1315
include("operations.jl")
1416
include("rules.jl")
1517
include("rule_definition_tools.jl")

src/composite_core.jl

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
struct PrimalAdditionFailedException{P} <: Exception
2+
primal::P
3+
differential::Composite{P}
4+
original::Exception
5+
end
6+
7+
function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P}
8+
println(io, "Could not construct $P after addition.")
9+
println(io, "This probably means no default constructor is defined.")
10+
println(io, "Either define a default constructor")
11+
printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue)
12+
println(io, "\nor overload")
13+
printstyled(io,
14+
"ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))";
15+
color=:blue
16+
)
17+
println(io, "\nor overload")
18+
printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue)
19+
println(io, "\nOriginal Exception:")
20+
printstyled(io, err.original; color=:yellow)
21+
println(io)
22+
end
23+
24+
"""
25+
backing(x)
26+
27+
Accesses the backing field of a `Composite`,
28+
or destructures any other composite type into a `NamedTuple`.
29+
Identity function on `Tuple`. and `NamedTuple`s.
30+
31+
This is an internal function used to simplify operations between `Composite`s and the
32+
primal types.
33+
"""
34+
backing(x::Tuple) = x
35+
backing(x::NamedTuple) = x
36+
backing(x::Composite) = getfield(x, :backing)
37+
38+
function backing(x::T)::NamedTuple where T
39+
!isstructtype(T) && throw(DomainError(T, "backing can only be use on composite types"))
40+
nfields = fieldcount(T)
41+
names = ntuple(ii->fieldname(T, ii), nfields)
42+
types = ntuple(ii->fieldtype(T, ii), nfields)
43+
44+
if @generated
45+
# @btime (()->ChainRulesCore.backing(Foo(1.0, 2.0)))()
46+
## 5.590 ns (1 allocation: 32 bytes)
47+
48+
vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...)
49+
return :(NamedTuple{$names, Tuple{$(types...)}}($vals))
50+
else
51+
vals = ntuple(ii->getfield(x, ii), nfields)
52+
return NamedTuple{names, Tuple{types...}}(vals)
53+
end
54+
end
55+
56+
"""
57+
construct(::Type{T}, fields::[NamedTuple|Tuple])
58+
59+
Constructs an object of type `T`, with the given fields.
60+
Fields must be correct in name and type, and `T` must have a default constructor.
61+
62+
This internally is called to construct structs of the primal type `T`,
63+
after an operation such as the addition of a primal to a composite.
64+
65+
It should be overloaded, if `T` does not have a default constructor,
66+
or if `T` needs to maintain some invarients between its fields.
67+
"""
68+
function construct(::Type{T}, fields::NamedTuple{L}) where {T, L}
69+
# Tested and verified that that this avoids a ton of allocations
70+
if length(L) !== fieldcount(T)
71+
# if length is equal but names differ then we will catch that below anyway.
72+
throw(ArgumentError("Unmatched fields. Type: $(fieldnames(T)), NamedTuple: $L"))
73+
end
74+
75+
if @generated
76+
vals = (:(getproperty(fields, $(QuoteNode(fname)))) for fname in fieldnames(T))
77+
return :(T($(vals...)))
78+
else
79+
return T((getproperty(fields, fname) for fname in fieldnames(T))...)
80+
end
81+
end
82+
83+
construct(::Type{T}, fields::T) where T<:NamedTuple = fields
84+
construct(::Type{T}, fields::T) where T<:Tuple = fields
85+
86+
elementwise_add(a::Tuple, b::Tuple) = map(+, a, b)
87+
88+
function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn}
89+
# Rule of Composite addition: any fields not present are implict hard Zeros
90+
91+
# Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base.
92+
# https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231
93+
if @generated
94+
names = Base.merge_names(an, bn)
95+
types = Base.merge_types(names, a, b)
96+
97+
vals = map(names) do field
98+
a_field = :(getproperty(a, $(QuoteNode(field))))
99+
b_field = :(getproperty(b, $(QuoteNode(field))))
100+
val_expr = if Base.sym_in(field, an)
101+
if Base.sym_in(field, bn)
102+
# in both
103+
:($a_field + $b_field)
104+
else
105+
# only in `an`
106+
a_field
107+
end
108+
else # must be in `b` only
109+
b_field
110+
end
111+
end
112+
return :(NamedTuple{$names, $types}(($(vals...),)))
113+
else
114+
names = Base.merge_names(an, bn)
115+
types = Base.merge_types(names, typeof(a), typeof(b))
116+
vals = map(names) do field
117+
if Base.sym_in(field, an)
118+
a_field = getproperty(a, field)
119+
if Base.sym_in(field, bn)
120+
# in both
121+
b_field = getproperty(b, field)
122+
a_field + b_field
123+
else
124+
# only in `an`
125+
a_field
126+
end
127+
else # must be in `b` only
128+
getproperty(b, field)
129+
end
130+
end
131+
return NamedTuple{names,types}(vals)
132+
end
133+
end

src/differential_arithmetic.jl

Lines changed: 15 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ subtypes, as we know the full set that might be encountered.
77
Thus we can avoid any ambiguities.
88
99
Notice:
10-
The precidence goes: (:Wirtinger, :Zero, :DoesNotExist, :One, :AbstractThunk, :Any)
10+
The precedence goes:
11+
`Wirtinger, Zero, DoesNotExist, One, AbstractThunk, Composite, Any`
1112
Thus each of the @eval loops creating definitions of + and *
1213
defines the combination this type with all types of lower precidence.
1314
This means each eval loops is 1 item smaller than the previous.
@@ -93,69 +94,19 @@ end
9394
# We intentionally do not define, `Base.*(::Composite, ::Composite)` as that is not meaningful
9495
# In general one doesn't have to represent multiplications of 2 differentials
9596
# Only of a differential and a scaling factor (generally `Real`)
96-
Base.*(s::Any, comp::Composite) = map(x->s*x, comp)
97-
Base.*(comp::Composite, s::Any) = s*comp
98-
99-
function Base.:+(a::Composite{Primal, NamedTuple{an}}, b::Composite{Primal, NamedTuple{bn}}) where Primal
100-
# Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base.
101-
# https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231
102-
if @generated
103-
names = Base.merge_names(an, bn)
104-
types = Base.merge_types(names, a, b)
105-
106-
vals = map(names) do field
107-
a_field = :(getproperty(:a, $(QuoteNode(field))))
108-
b_field = :(getproperty(:b, $(QuoteNode(field))))
109-
val_expr = if Base.sym_in(field, an)
110-
if Base.sym_in(field, bn)
111-
# in both
112-
:($a_field + $b_field)
113-
else
114-
# only in `an`
115-
a_field
116-
end
117-
else # must be in `b` only
118-
b_field
119-
end
120-
end
121-
return :(NamedTuple{$names, $types}(($(vals...),)))
122-
else
123-
names = Base.merge_names(an, bn)
124-
types = Base.merge_types(names, typeof(a), typeof(b))
125-
vals = map(names) do field
126-
val_expr = if Base.sym_in(field, an)
127-
a_field = getproperty(a, field)
128-
if Base.sym_in(field, bn)
129-
# in both
130-
b_field = getproperty(a, field)
131-
:($a_field + $b_field)
132-
else
133-
# only in `an`
134-
a_field
135-
end
136-
else # must be in `b` only
137-
b_field = getproperty(a, field)
138-
b_field
139-
end
140-
end
141-
NamedTuple{names,types}(map(n->getfield(sym_in(n, bn) ? b : a, n), names))
142-
end
143-
end
144-
end
97+
Base.:*(s::Any, comp::Composite) = map(x->s*x, comp)
98+
Base.:*(comp::Composite, s::Any) = map(x->x*s, comp)
14599

146-
# this should not need to be generated, # TODO test that
147-
function Base.:+(a::Composite{Primal, <:Tuple}, b::Composite{Primal, <:Tuple}) where Primal
148-
# TODO: should we even allow it on different lengths?
149-
short, long = length(a) < length(b) ? (a.backing, b.backing) : (b.backing, a.backing)
150-
backing = ntuple(length(long)) do ii
151-
long_val = getfield(long, ii)
152-
if ii <= length(short)
153-
short_val = getfield(short, ii)
154-
return short_val + long_val
155-
else
156-
return long_val
157-
end
158-
end
159100

160-
return Composite{Primal, typeof(backing)}(backing)
101+
function Base.:+(a::Composite{P}, b::Composite{P}) where P
102+
data = elementwise_add(backing(a), backing(b))
103+
return Composite{P, typeof(data)}(data)
104+
end
105+
function Base.:+(a::P, d::Composite{P}) where P
106+
try
107+
return construct(P, elementwise_add(backing(a), backing(d)))
108+
catch err
109+
throw(PrimalAdditionFailedException(a, d, err))
110+
end
161111
end
112+
Base.:+(a::Composite{P}, b::P) where P = b + a

src/differentials.jl

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ Base.iterate(::One, ::Any) = nothing
149149
#####
150150
##### `AbstractThunk
151151
#####
152+
152153
abstract type AbstractThunk <: AbstractDifferential end
153154

154155
Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x))
@@ -237,8 +238,17 @@ macro thunk(body)
237238
return :(Thunk($(esc(func))))
238239
end
239240

241+
"""
242+
unthunk(x)
243+
244+
`unthunk` removes 1 layer of thunking from an `AbstractThunk`,
245+
and on all other types is the `identity` function.
246+
"""
247+
unthunk(x) = x
248+
unthunk(x::Thunk) = x()
249+
240250
# have to define this here after `@thunk` and `Thunk` is defined
241-
Base.conj(x::AbstractThunk) = @thunk(conj(extern(x)))
251+
Base.conj(x::AbstractThunk) = @thunk(conj(unthunk(x)))
242252

243253
(x::Thunk)() = x.f()
244254
@inline unthunk(x::Thunk) = x()
@@ -286,61 +296,68 @@ const NO_FIELDS = DoesNotExist()
286296

287297

288298
"""
289-
Composite{Primal, T} <: AbstractDifferential
299+
Composite{P, T} <: AbstractDifferential
290300
291301
This type represents the differential for a `struct`/`NamedTuple`, or `Tuple`.
292-
`Primal` is the the corresponding primal type that this is a differential for.
302+
`P` is the the corresponding primal type that this is a differential for.
293303
294-
`Composite{Primal}` should have fields (technically properties), that match to a subset of the
304+
`Composite{P}` should have fields (technically properties), that match to a subset of the
295305
fields of the primal type; and each should be a differential type matching to the primal
296306
type of that field.
297-
Fields of the Primal that are not present in the Composite are treated as `Zero`.
307+
Fields of the P that are not present in the Composite are treated as `Zero`.
298308
299-
`T` is an implementation detail representing the backing datastructure.
309+
`T` is an implementation detail representing the backing data structure.
300310
For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
301311
It should not be passed in by user.
302312
"""
303-
struct Composite{Primal, T} <: AbstractDifferential
313+
struct Composite{P, T} <: AbstractDifferential
314+
# Note: If T is a Tuple, then P is also a Tuple
315+
# (but potentially a different one, as it doesn't contain differentials)
304316
backing::T
305317
end
306318

307-
308-
function Composite{Primal}(;kwargs...) where Primal
309-
backing = (; kwargs...)
310-
return Composite{Primal, typeof(backing)}(backing)
319+
function Composite{P}(; kwargs...) where P
320+
backing = (; kwargs...) # construct as NamedTuple
321+
return Composite{P, typeof(backing)}(backing)
311322
end
312323

313-
function Composite{Primal}(args...) where Primal
314-
return Composite{Primal, typeof(args)}(args)
324+
function Composite{P}(args...) where P
325+
return Composite{P, typeof(args)}(args)
315326
end
316327

317-
function Base.show(io::IO, comp::Composite{Primal})
328+
function Base.show(io::IO, comp::Composite{P}) where P
318329
print(io, "Composite{")
319-
show(io, Primal)
330+
show(io, P)
320331
print(io, "}")
321332
# allow Tuple or NamedTuple `show` to do the rendering of brackets etc
322-
show(io, comp.backing)
333+
show(io, backing(comp))
323334
end
324335

325-
#TODO think about this, for if we are missing fields
326-
#Base.convert(::Type{Primal}, comp::Composite{Primal})
327-
Base.convert(::Type{<:NamedTuple}, comp::Composite{<:Any, <:NamedTuple}) = comp.backing
328-
Base.convert(::Type{<:Tuple}, comp::Composite{<:Any, <:Tuple}) = comp.backing
336+
Base.convert(::Type{<:NamedTuple}, comp::Composite{<:Any, <:NamedTuple}) = backing(comp)
337+
Base.convert(::Type{<:Tuple}, comp::Composite{<:Any, <:Tuple}) = backing(comp)
338+
339+
Base.getindex(comp::Composite, idx) = getindex(backing(comp), idx)
340+
Base.getproperty(comp::Composite, idx::Int) = getproperty(backing(comp), idx) # for Tuple
341+
Base.getproperty(comp::Composite, idx::Symbol) = getproperty(backing(comp), idx)
342+
Base.propertynames(comp::Composite) = propertynames(backing(comp))
329343

330-
Base.getindex(comp::Composite, idx) = getindex(comp.backing)
331-
Base.getproperty(comp::Composite, idx) = getproperty(comp.backing, idx)
332-
Base.propertynames(comp::Composite) = propertynames(comp.backing)
333-
Base.iterate(comp::Compositem, args...) = iterate(comp.backing, args...)
334-
Base.length(comp::Composite) = length(comp.backing)
344+
Base.iterate(comp::Composite, args...) = iterate(backing(comp), args...)
345+
Base.length(comp::Composite) = length(backing(comp))
346+
Base.eltype(::Type{<:Composite{<:Any, T}}) where T = eltype(T)
335347

336-
map(f, comp::Composite{Primal, <:Tuple}) where Primal = Composite{Primal}(map(f, comp.backing))
337-
function map(f, comp::Composite{Primal, <:NamedTuple{L}}) where{Primal, L}
338-
vals = map(f, Tuple(comp.backing))
348+
function Base.map(f, comp::Composite{P, <:Tuple}) where P
349+
vals::Tuple = map(f, backing(comp))
350+
return Composite{P, typeof(vals)}(vals)
351+
end
352+
function Base.map(f, comp::Composite{P, <:NamedTuple{L}}) where{P, L}
353+
vals = map(f, Tuple(backing(comp)))
339354
named_vals = NamedTuple{L, typeof(vals)}(vals)
340-
return Composite{Primal}(named_vals)
355+
return Composite{P, typeof(named_vals)}(named_vals)
341356
end
342357

343-
Base.conj(comp::Composite{Primal}) = map(conj, comp)
358+
Base.conj(comp::Composite) = map(conj, comp)
359+
360+
extern(comp::Composite) = backing(map(extern, comp)) # gives a NamedTuple or Tuple
344361

345362
#==============================================================================#
346363

0 commit comments

Comments
 (0)