Skip to content

Commit 3394de6

Browse files
authored
implement projection for InplaceableThunk (#541)
1 parent c203e94 commit 3394de6

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.12"
3+
version = "1.12.1"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/projection.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ ProjectTo(::Any) = identity
122122

123123
# Thunks
124124
(project::ProjectTo)(dx::Thunk) = Thunk(project dx.f)
125+
(project::ProjectTo)(dx::InplaceableThunk) = project(dx.val)
125126

126127
# Zero
127128
ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pass makes this one projector,

test/projection.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,13 @@ struct NoSuperType end
427427
@test unthunk(pth) === 6.0 + 0.0im
428428
end
429429

430+
@testset "InplaceableThunk" begin
431+
it = InplaceableThunk(x -> x + 6, @thunk 1 + 2 + 3)
432+
pt = ProjectTo(4 + 5im)(it)
433+
@test pt isa Thunk
434+
@test unthunk(pt) === 6.0 + 0.0im
435+
end
436+
430437
@testset "Tangent" begin
431438
x = 1:3.0
432439
dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent())

0 commit comments

Comments
 (0)