Skip to content

Commit 92ade11

Browse files
committed
handle unassigned a bit more
1 parent 1c91ed4 commit 92ade11

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

src/tangent_types/abstract_zero.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,9 @@ For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is d
107107
Exactly how it should be used (e.g. is it forward-mode only?)
108108
"""
109109
function zero_tangent end
110-
zero_tangent(::AbstractString) = ZeroTangent()
111-
# zero_tangent(::Number) = zero(x) # TODO: do we want this?
112-
zero_tangent(primal::Array{<:Number}) = zero(primal) # TODO: do we want this?
113-
zero_tangent(primal::Array) = map(zero_tangent, primal)
110+
111+
zero_tangent(x::Number) = zero(x)
112+
114113
@generated function zero_tangent(primal)
115114
has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples
116115
zfield_exprs = map(fieldnames(primal)) do fname
@@ -119,4 +118,23 @@ zero_tangent(primal::Array) = map(zero_tangent, primal)
119118
end
120119
backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...))
121120
return :($MutableTangent{$primal}($backing_expr))
122-
end
121+
end
122+
123+
function zero_tangent(x::Array{P, N}) where {P, N}
124+
(isbitstype(P) || all(i->isassigned(x,i), eachindex(x))) && return map(zero_tangent, x)
125+
126+
# Now we need to handle nonfully assigned arrays
127+
# see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265
128+
y = Array{guess_zero_tangent_type(P), N}(undef, size(x)...)
129+
@inbounds for n in eachindex(y)
130+
if isassigned(x, n)
131+
y[n] = zero_tangent(x[n])
132+
end
133+
end
134+
return y
135+
end
136+
137+
guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
138+
guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = Array{guess_zero_tangent_type(T), N}
139+
guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634
140+
# TODO: we might be able to do better than this. even without.

test/tangent_types/abstract_zero.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,29 @@ end
172172

173173
@test zero_tangent([1.0, 2.0]) == [0.0, 0.0]
174174
@test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]]
175+
176+
@testset "undef elements" begin
177+
x = Vector{Vector{Float64}}(undef, 3)
178+
x[2] = [1.0,2.0]
179+
dx = zero_tangent(x)
180+
@test dx isa Vector{Vector{Float64}}
181+
@test length(dx) == 3
182+
@test !isassigned(dx, 1)
183+
@test dx[2] == [0.0, 0.0]
184+
@test !isassigned(dx, 3)
185+
186+
187+
a = Vector{MutDemo}(undef, 3)
188+
a[2] = MutDemo(1.5)
189+
da = zero_tangent(a)
190+
@test !isassigned(da, 1)
191+
@test iszero(da[2])
192+
@test !isassigned(da, 3)
193+
194+
195+
db = zero_tangent(Vector{MutDemo}(undef, 3))
196+
@test all(ii->!isassigned(db,ii), eachindex(db))
197+
@test length(db)==3
198+
@test db isa Vector
199+
end
175200
end

0 commit comments

Comments
 (0)