Skip to content

Commit 409c8d0

Browse files
committed
give up on Any[NoTangent(), ...] case
1 parent 7c84d82 commit 409c8d0

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/projection.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,7 @@ function ProjectTo(xs::AbstractArray)
199199
end
200200

201201
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
202-
# Trivial case:
203-
if all(el -> el isa AbstractZero, dx)
204-
return NoTangent()
205-
end
206-
# Now deal with shape. The rule is that we reshape to add or remove trivial dimensions
202+
# First deal with shape. The rule is that we reshape to add or remove trivial dimensions
207203
# like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
208204
dy = if axes(dx) == project.axes
209205
dx
@@ -226,6 +222,11 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
226222
return dz
227223
end
228224

225+
# Trivial case -- won't collapse Any[NoTangent(), NoTangent()]
226+
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{<:AbstractZero})
227+
return NoTangent()
228+
end
229+
229230
# Row vectors aren't acceptable as gradients for 1-row matrices:
230231
function (project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
231232
return project(reshape(vec(dx), 1, :))

test/projection.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
7272

7373
# some bugs
7474
@test pvec3(fill(NoTangent(), 3)) === NoTangent() #410, was an array of such
75-
@test pvec3(Any[NoTangent(), NoTangent(), NoTangent()]) === NoTangent()
7675
@test ProjectTo([pi])([1]) isa Vector{Int} #423, was Irrational -> Bool -> NoTangent
7776
end
7877

0 commit comments

Comments
 (0)