Skip to content

Commit 869f452

Browse files
authored
fix #685 (#688)
* fix #685 * comment * comment * InplaceableThunk * v1.25.2
1 parent 46f5d95 commit 869f452

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.25.1"
3+
version = "1.25.2"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/projection.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ ProjectTo(::Any) = identity
128128
ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pass makes this one projector,
129129
(::ProjectTo{NoTangent})(dx) = NoTangent() # but this is the projection only for nonzero gradients,
130130
(::ProjectTo{NoTangent})(dx::AbstractZero) = dx # and this one solves an ambiguity.
131+
(::ProjectTo{NoTangent})(::InplaceableThunk) = NoTangent() # solves ambiguity, #685
132+
(::ProjectTo{NoTangent})(::Thunk) = NoTangent() # solves ambiguity, #685
131133

132134
# Also, any explicit construction with fields, where all fields project to zero, itself
133135
# projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
@@ -277,7 +279,7 @@ end
277279
# but as `Ref{Any}((x=val,))`. Here we use a Tangent, there is at present no mutable version, but see
278280
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105
279281
function ProjectTo(x::Ref)
280-
sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
282+
sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
281283
return ProjectTo{Tangent{typeof(x)}}(; x=sub)
282284
end
283285
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx))))

test/projection.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ struct NoSuperType end
8080

8181
prow = ProjectTo([1im 2 3im])
8282
@test prow(transpose([1, 2, 3 + 4.0im])) == [1 2 3 + 4im]
83-
@test prow(transpose([1, 2, 3 + 4.0im])) isa Matrix # row vectors may not pass through
83+
@test prow(transpose([1, 2, 3 + 4.0im])) isa Matrix # row vectors may not pass through
8484
@test prow(adjoint([1, 2, 3 + 5im])) == [1 2 3 - 5im]
8585
@test prow(adjoint([1, 2, 3])) isa Matrix
8686

@@ -145,7 +145,7 @@ struct NoSuperType end
145145

146146
@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
147147
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
148-
148+
149149
@test ProjectTo(Ref(1.0))(Ref(NoTangent())) === NoTangent() # collapse all-zero
150150
end
151151

@@ -376,7 +376,7 @@ struct NoSuperType end
376376

377377
pvec3 = ProjectTo([1, 2, 3])
378378
@test axes(pvec3(OffsetArray(rand(3), 0:2))) == (1:3,)
379-
@test pvec3(OffsetArray(rand(3), 0:2)) isa Vector # relies on axes === axes test
379+
@test pvec3(OffsetArray(rand(3), 0:2)) isa Vector # relies on axes === axes test
380380
@test pvec3(OffsetArray(rand(3,1), 0:2, 0:0)) isa Vector
381381
end
382382

@@ -463,4 +463,12 @@ struct NoSuperType end
463463
psymm = ProjectTo(Symmetric(rand(10^3, 10^3)))
464464
@test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
465465
end
466+
467+
@testset "#685" begin
468+
@test ProjectTo(BitArray([0]))([1.0]) == NoTangent()
469+
@test ProjectTo(BitArray([0]))(@thunk [1.0]) == NoTangent()
470+
471+
it = InplaceableThunk(x -> x + [1], @thunk [1.0])
472+
@test ProjectTo(BitArray([0]))(it) == NoTangent()
473+
end
466474
end

0 commit comments

Comments
 (0)