Skip to content

Commit 7e7b441

Browse files
committed
handle unassigned a bit more
1 parent 235fbcd commit 7e7b441

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

src/tangent_types/abstract_zero.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,9 @@ For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is d
103103
Exactly how it should be used (e.g. is it forward-mode only?)
104104
"""
105105
function zero_tangent end
106-
zero_tangent(::AbstractString) = ZeroTangent()
107-
# zero_tangent(::Number) = zero(x) # TODO: do we want this?
108-
zero_tangent(primal::Array{<:Number}) = zero(primal) # TODO: do we want this?
109-
zero_tangent(primal::Array) = map(zero_tangent, primal)
106+
107+
zero_tangent(x::Number) = zero(x)
108+
110109
@generated function zero_tangent(primal)
111110
has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples
112111
zfield_exprs = map(fieldnames(primal)) do fname
@@ -115,4 +114,23 @@ zero_tangent(primal::Array) = map(zero_tangent, primal)
115114
end
116115
backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...))
117116
return :($MutableTangent{$primal}($backing_expr))
118-
end
117+
end
118+
119+
function zero_tangent(x::Array{P, N}) where {P, N}
120+
(isbitstype(P) || all(i->isassigned(x,i), eachindex(x))) && return map(zero_tangent, x)
121+
122+
# Now we need to handle nonfully assigned arrays
123+
# see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265
124+
y = Array{guess_zero_tangent_type(P), N}(undef, size(x)...)
125+
@inbounds for n in eachindex(y)
126+
if isassigned(x, n)
127+
y[n] = zero_tangent(x[n])
128+
end
129+
end
130+
return y
131+
end
132+
133+
guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
134+
guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = Array{guess_zero_tangent_type(T), N}
135+
guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634
136+
# TODO: we might be able to do better than this. even without.

test/tangent_types/abstract_zero.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,29 @@ end
166166

167167
@test zero_tangent([1.0, 2.0]) == [0.0, 0.0]
168168
@test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]]
169+
170+
@testset "undef elements" begin
171+
x = Vector{Vector{Float64}}(undef, 3)
172+
x[2] = [1.0,2.0]
173+
dx = zero_tangent(x)
174+
@test dx isa Vector{Vector{Float64}}
175+
@test length(dx) == 3
176+
@test !isassigned(dx, 1)
177+
@test dx[2] == [0.0, 0.0]
178+
@test !isassigned(dx, 3)
179+
180+
181+
a = Vector{MutDemo}(undef, 3)
182+
a[2] = MutDemo(1.5)
183+
da = zero_tangent(a)
184+
@test !isassigned(da, 1)
185+
@test iszero(da[2])
186+
@test !isassigned(da, 3)
187+
188+
189+
db = zero_tangent(Vector{MutDemo}(undef, 3))
190+
@test all(ii->!isassigned(db,ii), eachindex(db))
191+
@test length(db)==3
192+
@test db isa Vector
193+
end
169194
end

0 commit comments

Comments
 (0)