diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index f921db29d..d9df07baa 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -126,24 +126,38 @@ end @generated function zero_tangent(primal) fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. - zfield_exprs = map(fieldnames(primal)) do fname - fval = :( - if isdefined(primal, $(QuoteNode(fname))) - zero_tangent(getfield(primal, $(QuoteNode(fname)))) - else - # This is going to be potentially bad, but that's what they get for not giving us a primal - # This will never me mutated inplace, rather it will alway be replaced with an actual value first - ZeroTangent() - end - ) - Expr(:kw, fname, fval) + + # easy case exit early, can't hold references, can't be a reference. + if isbitstype(primal) + zfield_exprs = map(fieldnames(primal)) do fname + fval = :(zero_tangent(getfield(primal, $(QuoteNode(fname))))) + Expr(:kw, fname, fval) + end + return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...)))) end - return if has_mutable_tangent(primal) - any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype - # If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent - fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype)) + + # hard case need to be prepared for references to this, or that are contained within this + quote + counts = $count_references(primal) + any_mask = $(Expr(:tuple, Expr(:parameters, map(fieldnames(primal), fieldtypes(primal)) do fname, ftype + # If it is is unassigned, or if it doesn't have a concrete type, or we have multiple reference to it + # then let it take any value for its tangent + fdef = :( + !isdefined(primal, $(QuoteNode(fname))) || + !isconcretetype($ftype) || + get(counts, $(QuoteNode(fname)), 0) > 1 + ) Expr(:kw, fname, fdef) - end + end...))) + + # Construct tangents + + # Go back 
"""
    count_references(x) -> IdDict{Any,Int}

Walk the object graph rooted at `x` and return a table mapping each visited
non-`isbits` object (compared by identity) to the number of times it was reached.

A count greater than 1 means the object is shared or is part of a cycle.
`isbits` values are never recorded: they cannot be a reference and cannot hold one.
"""
count_references(x) = count_references!(IdDict{Any,Int}(), x)

"""
    count_references!(counts::IdDict{Any,Int}, x) -> counts

Record one more encounter of `x` in `counts`, and, the first time `x` is seen,
recurse into everything it holds. Returns `counts` (mutated in place).

Methods:
- generic `x`: recurses into every *assigned* field of `x`.
- `x::Array`: recurses into every *assigned* element, unless the element type is
  a bits type (then nothing inside can be a reference).
- `x::DataType`: ignored entirely; `counts` is returned untouched.
"""
function count_references!(counts::IdDict{Any,Int}, x)
    isbits(x) && return counts  # can't be a reference and can't hold a reference
    seen = get(counts, x, 0) + 1
    counts[x] = seen  # increment *before* recursing so cycles terminate
    seen == 1 || return counts  # only recurse the first time we meet `x`
    for ii in 1:fieldcount(typeof(x))
        # Unassigned fields have no value to follow; `getfield` would throw on them.
        isdefined(x, ii) || continue
        count_references!(counts, getfield(x, ii))
    end
    return counts
end

function count_references!(counts::IdDict{Any,Int}, x::Array)
    seen = get(counts, x, 0) + 1
    counts[x] = seen  # increment *before* recursing so cycles terminate
    isbitstype(eltype(x)) && return counts  # no need to look inside
    seen == 1 || return counts  # only recurse the first time we meet `x`
    for ii in eachindex(x)
        # `#undef` slots have no value to follow; indexing them would throw.
        isassigned(x, ii) || continue
        count_references!(counts, x[ii])
    end
    return counts
end

# Types themselves are not walked and not counted.
count_references!(counts::IdDict{Any,Int}, ::DataType) = counts
@test_broken d_ca = zero_tangent(ca) - @test_broken d_ca[1] == 0.0 - @test_broken d_ca[2] === _ca + @test d_ca = zero_tangent(ca) + @test d_ca[1] == 0.0 + @test d_ca[2] === _ca # Idea: check if typeof(xs) <: eltype(xs), if so need to cache it before computing xs = Any[1.5] push!(xs, xs) - @test_broken d_xs = zero_tangent(xs) - @test_broken d_xs[1] == 0.0 - @test_broken d_xs[2] == d_xs + @test d_xs = zero_tangent(xs) + @test d_xs[1] == 0.0 + @test d_xs[2] == d_xs end end