Skip to content

Commit 752732e

Browse files
committed
much AbstractThunk
1 parent 48a0391 commit 752732e

File tree

2 files changed

+38
-50
lines changed

2 files changed

+38
-50
lines changed

src/differential_arithmetic.jl

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered.
77
Thus we can avoid any ambiguities.
88
99
Notice:
10-
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any)
10+
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
1111
Thus each of the @eval loops creating definitions of + and *
1212
defines the combination this type with all types of lower precidence.
1313
This means each eval loops is 1 item smaller than the previous.
@@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger)
3636
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
3737
end
3838

39-
for T in (:Casted, :Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any)
39+
for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
4040
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
4141
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b
4242

@@ -47,7 +47,7 @@ end
4747

4848
Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
4949
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))
50-
for T in (:Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any)
50+
for T in (:Zero, :DNE, :One, :AbstractThunk, :Any)
5151
@eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b))
5252
@eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value))
5353

@@ -58,7 +58,7 @@ end
5858

5959
Base.:+(::Zero, b::Zero) = Zero()
6060
Base.:*(::Zero, ::Zero) = Zero()
61-
for T in (:DNE, :One, :Thunk, :InplaceableThunk, :Any)
61+
for T in (:DNE, :One, :AbstractThunk, :Any)
6262
@eval Base.:+(::Zero, b::$T) = b
6363
@eval Base.:+(a::$T, ::Zero) = a
6464

@@ -69,7 +69,7 @@ end
6969

7070
Base.:+(::DNE, ::DNE) = DNE()
7171
Base.:*(::DNE, ::DNE) = DNE()
72-
for T in (:One, :Thunk, :InplaceableThunk, :Any)
72+
for T in (:One, :AbstractThunk, :Any)
7373
@eval Base.:+(::DNE, b::$T) = b
7474
@eval Base.:+(a::$T, ::DNE) = a
7575

@@ -80,7 +80,7 @@ end
8080

8181
Base.:+(a::One, b::One) = extern(a) + extern(b)
8282
Base.:*(::One, ::One) = One()
83-
for T in (:Thunk, :InplaceableThunk, :Any)
83+
for T in (:AbstractThunk, :Any)
8484
@eval Base.:+(a::One, b::$T) = extern(a) + b
8585
@eval Base.:+(a::$T, b::One) = a + extern(b)
8686

@@ -89,23 +89,12 @@ for T in (:Thunk, :InplaceableThunk, :Any)
8989
end
9090

9191

92-
Base.:+(a::Thunk, b::Thunk) = extern(a) + extern(b)
93-
Base.:*(a::Thunk, b::Thunk) = extern(a) * extern(b)
94-
for T in (:InplaceableThunk, :Any)
95-
@eval Base.:+(a::Thunk, b::$T) = extern(a) + b
96-
@eval Base.:+(a::$T, b::Thunk) = a + extern(b)
92+
Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b)
93+
Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b)
94+
for T in (:Any,)
95+
@eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b
96+
@eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b)
9797

98-
@eval Base.:*(a::Thunk, b::$T) = extern(a) * b
99-
@eval Base.:*(a::$T, b::Thunk) = a * extern(b)
100-
end
101-
102-
# InplaceableThunk acts just like Thunk
103-
Base.:+(a::InplaceableThunk, b::InplaceableThunk) = extern(a) + extern(b)
104-
Base.:*(a::InplaceableThunk, b::InplaceableThunk) = extern(a) * extern(b)
105-
for T in (:Any, )
106-
@eval Base.:+(a::InplaceableThunk, b::$T) = extern(a) + b
107-
@eval Base.:+(a::$T, b::InplaceableThunk) = a + extern(b)
108-
109-
@eval Base.:*(a::InplaceableThunk, b::$T) = extern(a) * b
110-
@eval Base.:*(a::$T, b::InplaceableThunk) = a * extern(b)
98+
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
99+
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
111100
end

src/differentials.jl

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,24 @@ Base.iterate(x::One) = (x, nothing)
173173
Base.iterate(::One, ::Any) = nothing
174174

175175

176+
#####
177+
##### `AbstractThunk
178+
#####
179+
abstract type AbstractThunk <: AbstractDifferential end
180+
181+
Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x))
182+
183+
@inline function Base.iterate(x::AbstractThunk)
184+
externed = extern(x)
185+
element, state = iterate(externed)
186+
return element, (externed, state)
187+
end
188+
189+
@inline function Base.iterate(::AbstractThunk, (externed, state))
190+
element, new_state = iterate(externed, state)
191+
return element, (externed, new_state)
192+
end
193+
176194
#####
177195
##### `Thunk`
178196
#####
@@ -218,31 +236,20 @@ itself a `Thunk`.
218236
If you got the expression from another `rrule` (or `frule`), you don't need to
219237
`@thunk` it since it will have been thunked if required, by the defining rule.
220238
"""
221-
struct Thunk{F} <: AbstractDifferential
239+
struct Thunk{F} <: AbstractThunk
222240
f::F
223241
end
224242

225243
macro thunk(body)
226244
return :(Thunk(() -> $(esc(body))))
227245
end
228246

229-
(x::Thunk)() = x.f()
230-
@inline extern(x::Thunk) = extern(x())
231-
232-
Base.Broadcast.broadcastable(x::Thunk) = broadcastable(extern(x))
247+
# have to define this here after `@thunk` and `Thunk` is defined
248+
Base.conj(x::AbstractThunk) = @thunk(conj(extern(x)))
233249

234-
@inline function Base.iterate(x::Thunk)
235-
externed = extern(x)
236-
element, state = iterate(externed)
237-
return element, (externed, state)
238-
end
239250

240-
@inline function Base.iterate(::Thunk, (externed, state))
241-
element, new_state = iterate(externed, state)
242-
return element, (externed, new_state)
243-
end
244-
245-
Base.conj(x::Thunk) = @thunk(conj(extern(x)))
251+
(x::Thunk)() = x.f()
252+
@inline extern(x::Thunk) = extern(x())
246253

247254
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")
248255

@@ -259,29 +266,21 @@ but it should do this more efficently than simply doing this directly.
259266
Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
260267
and destroy its inplacability.
261268
"""
262-
struct InplaceableThunk{T<:Thunk, F} <: AbstractDifferential
269+
struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
263270
val::T
264271
add!::F
265272
end
266273

267274
(x::InplaceableThunk)() = x.val()
268275
@inline extern(x::InplaceableThunk) = extern(x.val)
269276

270-
Base.Broadcast.broadcastable(x::InplaceableThunk) = broadcastable(x.val)
271-
272-
@inline function Base.iterate(x::InplaceableThunk, args...)
273-
return iterate(x.val, args...)
274-
end
275-
276-
Base.conj(x::InplaceableThunk) = conj(x.val)
277-
278277
function Base.show(io::IO, x::InplaceableThunk)
279278
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
280279
end
281280

282281
# The real reason we have this:
283282
accumulate!(Δ, ∂::InplaceableThunk) =.add!(Δ)
284-
store!(Δ, ∂::InplaceableThunk) =.add!((Δ.*=false)) # zero i, then add to it.
283+
store!(Δ, ∂::InplaceableThunk) =.add!((Δ.*=false)) # zero it, then add to it.
285284

286285
"""
287286
NO_FIELDS

0 commit comments

Comments
 (0)