Skip to content

Commit 6184811

Browse files
authored
Merge pull request #59 from JuliaDiff/ox/comp
Introduce Composite for structured deriviatives
2 parents 6c81169 + 3004b0c commit 6184811

19 files changed

+1025
-18
lines changed

src/ChainRulesCore.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
module ChainRulesCore
2-
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
2+
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
33

44
export frule, rrule
5-
export wirtinger_conjugate, wirtinger_primal, refine_differential
5+
export refine_differential, wirtinger_conjugate, wirtinger_primal
66
export @scalar_rule, @thunk
77
export extern, store!, unthunk
8-
export Wirtinger, Zero, One, DoesNotExist, Thunk, InplaceableThunk
8+
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Wirtinger, Zero
99
export NO_FIELDS
1010

11-
include("differentials.jl")
11+
include("compat.jl")
12+
13+
include("differentials/abstract_differential.jl")
14+
include("differentials/wirtinger.jl")
15+
include("differentials/zero.jl")
16+
include("differentials/does_not_exist.jl")
17+
include("differentials/one.jl")
18+
include("differentials/thunks.jl")
19+
include("differentials/composite.jl")
20+
1221
include("differential_arithmetic.jl")
22+
1323
include("operations.jl")
1424
include("rules.jl")
1525
include("rule_definition_tools.jl")

src/compat.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
if VERSION < v"1.2"
2+
Base.getproperty(x::Tuple, f::Int) = getfield(x, f)
3+
end

src/differential_arithmetic.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ subtypes, as we know the full set that might be encountered.
77
Thus we can avoid any ambiguities.
88
99
Notice:
10-
The precidence goes: (:Wirtinger, :Zero, :DoesNotExist, :One, :AbstractThunk, :Any)
10+
The precedence goes:
11+
`Wirtinger, Zero, DoesNotExist, One, AbstractThunk, Composite, Any`
1112
Thus each of the @eval loops creating definitions of + and *
1213
defines the combination this type with all types of lower precidence.
1314
This means each eval loops is 1 item smaller than the previous.
@@ -87,3 +88,25 @@ for T in (:Any,)
8788
@eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b
8889
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
8990
end
91+
92+
################## Composite ##############################################################
93+
94+
# We intentionally do not define, `Base.*(::Composite, ::Composite)` as that is not meaningful
95+
# In general one doesn't have to represent multiplications of 2 differentials
96+
# Only of a differential and a scaling factor (generally `Real`)
97+
Base.:*(s::Any, comp::Composite) = map(x->s*x, comp)
98+
Base.:*(comp::Composite, s::Any) = map(x->x*s, comp)
99+
100+
101+
function Base.:+(a::Composite{P}, b::Composite{P}) where P
102+
data = elementwise_add(backing(a), backing(b))
103+
return Composite{P, typeof(data)}(data)
104+
end
105+
function Base.:+(a::P, d::Composite{P}) where P
106+
try
107+
return construct(P, elementwise_add(backing(a), backing(d)))
108+
catch err
109+
throw(PrimalAdditionFailedException(a, d, err))
110+
end
111+
end
112+
Base.:+(a::Composite{P}, b::P) where P = b + a

src/differentials.jl

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ Base.iterate(::One, ::Any) = nothing
149149
#####
150150
##### `AbstractThunk
151151
#####
152+
152153
abstract type AbstractThunk <: AbstractDifferential end
153154

154155
Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x))
@@ -237,8 +238,17 @@ macro thunk(body)
237238
return :(Thunk($(esc(func))))
238239
end
239240

241+
"""
242+
unthunk(x)
243+
244+
`unthunk` removes 1 layer of thunking from an `AbstractThunk`,
245+
and on all other types is the `identity` function.
246+
"""
247+
unthunk(x) = x
248+
unthunk(x::Thunk) = x()
249+
240250
# have to define this here after `@thunk` and `Thunk` is defined
241-
Base.conj(x::AbstractThunk) = @thunk(conj(extern(x)))
251+
Base.conj(x::AbstractThunk) = @thunk(conj(unthunk(x)))
242252

243253
(x::Thunk)() = x.f()
244254
@inline unthunk(x::Thunk) = x()
@@ -284,6 +294,73 @@ function itself, when that function is not a closure.
284294
"""
285295
const NO_FIELDS = DoesNotExist()
286296

297+
298+
"""
299+
Composite{P, T} <: AbstractDifferential
300+
301+
This type represents the differential for a `struct`/`NamedTuple`, or `Tuple`.
302+
`P` is the the corresponding primal type that this is a differential for.
303+
304+
`Composite{P}` should have fields (technically properties), that match to a subset of the
305+
fields of the primal type; and each should be a differential type matching to the primal
306+
type of that field.
307+
Fields of the P that are not present in the Composite are treated as `Zero`.
308+
309+
`T` is an implementation detail representing the backing data structure.
310+
For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
311+
It should not be passed in by user.
312+
"""
313+
struct Composite{P, T} <: AbstractDifferential
314+
# Note: If T is a Tuple, then P is also a Tuple
315+
# (but potentially a different one, as it doesn't contain differentials)
316+
backing::T
317+
end
318+
319+
function Composite{P}(; kwargs...) where P
320+
backing = (; kwargs...) # construct as NamedTuple
321+
return Composite{P, typeof(backing)}(backing)
322+
end
323+
324+
function Composite{P}(args...) where P
325+
return Composite{P, typeof(args)}(args)
326+
end
327+
328+
function Base.show(io::IO, comp::Composite{P}) where P
329+
print(io, "Composite{")
330+
show(io, P)
331+
print(io, "}")
332+
# allow Tuple or NamedTuple `show` to do the rendering of brackets etc
333+
show(io, backing(comp))
334+
end
335+
336+
Base.convert(::Type{<:NamedTuple}, comp::Composite{<:Any, <:NamedTuple}) = backing(comp)
337+
Base.convert(::Type{<:Tuple}, comp::Composite{<:Any, <:Tuple}) = backing(comp)
338+
339+
Base.getindex(comp::Composite, idx) = getindex(backing(comp), idx)
340+
Base.getproperty(comp::Composite, idx::Int) = getproperty(backing(comp), idx) # for Tuple
341+
Base.getproperty(comp::Composite, idx::Symbol) = getproperty(backing(comp), idx)
342+
Base.propertynames(comp::Composite) = propertynames(backing(comp))
343+
344+
Base.iterate(comp::Composite, args...) = iterate(backing(comp), args...)
345+
Base.length(comp::Composite) = length(backing(comp))
346+
Base.eltype(::Type{<:Composite{<:Any, T}}) where T = eltype(T)
347+
348+
function Base.map(f, comp::Composite{P, <:Tuple}) where P
349+
vals::Tuple = map(f, backing(comp))
350+
return Composite{P, typeof(vals)}(vals)
351+
end
352+
function Base.map(f, comp::Composite{P, <:NamedTuple{L}}) where{P, L}
353+
vals = map(f, Tuple(backing(comp)))
354+
named_vals = NamedTuple{L, typeof(vals)}(vals)
355+
return Composite{P, typeof(named_vals)}(named_vals)
356+
end
357+
358+
Base.conj(comp::Composite) = map(conj, comp)
359+
360+
extern(comp::Composite) = backing(map(extern, comp)) # gives a NamedTuple or Tuple
361+
362+
#==============================================================================#
363+
287364
"""
288365
refine_differential(𝒟::Type, der)
289366
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#####
2+
##### `AbstractDifferential`
3+
#####
4+
5+
"""
6+
The subtypes of `AbstractDifferential` define a custom \"algebra\" for chain
7+
rule evaluation that attempts to factor various features like complex derivative
8+
support, broadcast fusion, zero-elision, etc. into nicely separated parts.
9+
10+
All subtypes of `AbstractDifferential` implement the following operations:
11+
12+
`+(a, b)`: linearly combine differential `a` and differential `b`
13+
14+
`*(a, b)`: multiply the differential `a` by the differential `b`
15+
16+
`Base.conj(x)`: complex conjugate of the differential `x`
17+
18+
`extern(x)`: convert `x` into an appropriate non-`AbstractDifferential` type for
19+
use outside of `ChainContext`.
20+
21+
Valid arguments to these operations are `T` where `T<:AbstractDifferential`, or
22+
where `T` has proper `+` and `*` implementations.
23+
24+
Additionally, all subtypes of `AbstractDifferential` support `Base.iterate` and
25+
`Base.Broadcast.broadcastable(x)`.
26+
"""
27+
abstract type AbstractDifferential end
28+
29+
Base.:+(x::AbstractDifferential) = x
30+
31+
"""
32+
extern(x)
33+
34+
Return `x` converted to an appropriate non-`AbstractDifferential` type, for use
35+
with external packages that might not handle `AbstractDifferential` types.
36+
37+
Note that this function may return an alias (not necessarily a copy) to data
38+
wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
39+
"""
40+
@inline extern(x) = x
41+
42+
@inline Base.conj(x::AbstractDifferential) = x
43+
44+
"""
45+
refine_differential(𝒟::Type, der)
46+
47+
Converts, if required, a differential object `der`
48+
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
49+
to another differential that is more suited for the domain given by the type 𝒟.
50+
Often this will behave as the identity function on `der`.
51+
"""
52+
refine_differential(::Any, der) = der # most of the time leave it alone.

0 commit comments

Comments
 (0)