Skip to content

Commit 07dfe60

Browse files
authored
Collapse zeros for Tuple & Ref tangents (#565)
* collapse zeros for Tuple & Ref tangents * v1.15.4
1 parent 4ce6418 commit 07dfe60

File tree

3 files changed

+24
-9
lines changed

3 files changed

+24
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.15.3"
3+
version = "1.15.4"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/projection.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,11 @@ end
283283
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx))))
284284
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref)
285285
dy = project.x(dx[])
286-
return project_type(project)(; x=dy)
286+
if dy isa AbstractZero
287+
return NoTangent()
288+
else
289+
return project_type(project)(; x=dy)
290+
end
287291
end
288292
# Since this works like a zero-array in broadcasting, it should also accept a number:
289293
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx))
@@ -321,7 +325,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
321325
end
322326
# Here map will fail if the lengths don't match, but gives a much less helpful error:
323327
dy = map((f, x) -> f(x), project.elements, dx)
324-
return project_type(project)(dy...)
328+
if all(d -> d isa AbstractZero, dy)
329+
return NoTangent()
330+
else
331+
return project_type(project)(dy...)
332+
end
325333
end
326334
function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple)
327335
dy = _project_namedtuple(backing(project), dx)
@@ -370,7 +378,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
370378
end
371379
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
372380
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
373-
return project_type(project)(dz...)
381+
if all(d -> d isa AbstractZero, dy)
382+
return NoTangent()
383+
else
384+
return project_type(project)(dz...)
385+
end
374386
end
375387

376388

test/projection.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,19 @@ struct NoSuperType end
143143

144144
@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
145145
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
146+
147+
@test ProjectTo(Ref(1.0))(Ref(NoTangent())) === NoTangent() # collapse all-zero
146148
end
147149

148150
@testset "Base: Tuple" begin
149151
pt1 = ProjectTo((1.0,))
150-
@test pt1((1 + im,)) == Tangent{Tuple{Float64}}(1.0,)
151-
@test pt1(pt1((1,))) == pt1(pt1((1,))) # accepts correct Tangent
152-
@test pt1(Tangent{Any}(1)) == pt1((1,)) # accepts Tangent{Any}
152+
@test @inferred(pt1((1 + im,))) == Tangent{Tuple{Float64}}(1.0,)
153+
@test @inferred(pt1(pt1((1,)))) == pt1(pt1((1,))) # accepts correct Tangent
154+
@test @inferred(pt1(Tangent{Any}(1))) == pt1((1,)) # accepts Tangent{Any}
153155
@test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector
154-
@test pt1(NoTangent()) === NoTangent()
155-
@test pt1(ZeroTangent()) === ZeroTangent()
156+
@test @inferred(pt1(NoTangent())) === NoTangent()
157+
@test @inferred(pt1(ZeroTangent())) === ZeroTangent()
158+
@test @inferred(pt1((NoTangent(),))) === NoTangent() # collapse all-zero
156159

157160
@test_throws Exception pt1([1, 2]) # DimensionMismatch, wrong length
158161
@test_throws Exception pt1([])

0 commit comments

Comments
 (0)