Skip to content

Commit 48a2e78

Browse files
authored
make extern(::Thunk) recursive (#48)
* make extern(::Thunk) recursive, and improve show(::Thunk) * bump version (breaking) * make this a unreleased version * mark nonbreaking * fix tabs
1 parent 355b830 commit 48a2e78

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.2.0"
3+
version = "0.2.1-DEV"
44

55
[compat]
66
julia = "^1.0"

src/differentials.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,24 @@ Base.iterate(::One, ::Any) = nothing
181181
Thunk(()->v)
182182
A thunk is a deferred computation.
183183
It wraps a zero argument closure that when invoked returns a differential.
184+
185+
Calling that thunk, calls the wrapped closure.
186+
`extern`ing thunks applies recursively, it also externs the differial that the closure returns.
187+
If you do not want that, then simply call the thunk
188+
189+
```
190+
julia> t = @thunk(@thunk(3))
191+
Thunk(var"##7#9"())
192+
193+
julia> extern(t)
194+
3
195+
196+
julia> t()
197+
Thunk(var"##8#10"())
198+
199+
julia> t()()
200+
3
201+
```
184202
"""
185203
struct Thunk{F} <: AbstractDifferential
186204
f::F
@@ -190,7 +208,8 @@ macro thunk(body)
190208
return :(Thunk(() -> $(esc(body))))
191209
end
192210

193-
@inline extern(x::Thunk) = x.f()
211+
(x::Thunk)() = x.f()
212+
@inline extern(x::Thunk) = extern(x())
194213

195214
Base.Broadcast.broadcastable(x::Thunk) = broadcastable(extern(x))
196215

@@ -206,3 +225,5 @@ end
206225
end
207226

208227
Base.conj(x::Thunk) = @thunk(conj(extern(x)))
228+
229+
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")

test/differentials.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,25 @@
4949
@test conj(o) == o
5050
end
5151

52+
@testset "Thunk" begin
53+
@test @thunk(3) isa Thunk
54+
55+
@testset "show" begin
56+
rep = repr(Thunk(rand))
57+
@test occursin(r"Thunk\(.*rand.*\)", rep)
58+
end
59+
60+
@testset "Externing" begin
61+
@test extern(@thunk(3)) == 3
62+
@test extern(@thunk(@thunk(3))) == 3
63+
end
64+
65+
@testset "calling thunks should call inner function" begin
66+
@test (@thunk(3))() == 3
67+
@test (@thunk(@thunk(3)))() isa Thunk
68+
end
69+
end
70+
5271
@testset "No ambiguities in $f" for f in (+, *)
5372
# We don't use `Test.detect_ambiguities` as we are only interested in
5473
# the +, and * operations. We also would catch any that are unrelated

0 commit comments

Comments
 (0)