Skip to content

Commit 88df0b4

Browse files
authored
Merge pull request #424 from mcabbott/projectbugs
Fix two bugs in `ProjectTo`
2 parents cfe3a81 + 9844bf7 commit 88df0b4

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.1.0"
3+
version = "1.2.0"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/projection.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ end
179179

180180
# For arrays of numbers, just store one projector:
181181
function ProjectTo(x::AbstractArray{T}) where {T<:Number}
182-
element = ProjectTo(zero(T))
182+
element = T <: Irrational ? ProjectTo{Real}() : ProjectTo(zero(T))
183183
if element isa ProjectTo{<:AbstractZero}
184184
return ProjectTo{NoTangent}() # short-circuit if all elements project to zero
185185
else
@@ -222,6 +222,9 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
222222
return dz
223223
end
224224

225+
# Trivial case, this won't collapse Any[NoTangent(), NoTangent()] but that's OK.
226+
(project::ProjectTo{AbstractArray})(dx::AbstractArray{<:AbstractZero}) = NoTangent()
227+
225228
# Row vectors aren't acceptable as gradients for 1-row matrices:
226229
function (project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
227230
return project(reshape(vec(dx), 1, :))

test/projection.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
6969
@test prow(transpose([1, 2, 3 + 4.0im])) isa Matrix # row vectors may not pass through
7070
@test prow(adjoint([1, 2, 3 + 5im])) == [1 2 3 - 5im]
7171
@test prow(adjoint([1, 2, 3])) isa Matrix
72+
73+
# some bugs
74+
@test pvec3(fill(NoTangent(), 3)) === NoTangent() #410, was an array of such
75+
@test ProjectTo([pi])([1]) isa Vector{Int} #423, was Irrational -> Bool -> NoTangent
7276
end
7377

7478
@testset "Base: arrays of arrays, etc" begin
@@ -160,6 +164,9 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
160164
@test pvecmat(collect.(zs)) == zs
161165
@test pvecmat(collect.(zs)) isa LinearAlgebra.AdjOrTransAbsVec
162166
end
167+
168+
# issue #410
169+
@test padj([NoTangent() NoTangent() NoTangent()]) === NoTangent()
163170
end
164171

165172
@testset "LinearAlgebra: dense structured matrices" begin

0 commit comments

Comments
 (0)