Description
I've encountered this in the wild, but here is a minimal breaking example:
julia> using Zygote
julia> using StatsBase
julia> gradient(w -> w.sum, AnalyticWeights([1.0, 2.0, 3.0]))
ERROR: BoundsError: attempt to access 0-element Vector{Any} at index []
Stacktrace:
  [1] throw_boundserror(A::Vector{Any}, I::Tuple{})
    @ Base ./abstractarray.jl:651
  [2] checkbounds
    @ ./abstractarray.jl:616 [inlined]
  [3] _getindex
    @ ./abstractarray.jl:1196 [inlined]
  [4] getindex(::Vector{Any})
    @ Base ./abstractarray.jl:1170
  [5] (::Zygote.var"#back#222"{:sum, Zygote.Context, AnalyticWeights{Float64, Float64, Vector{Float64}}, Float64})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsu1Y/src/lib/lib.jl:233
  [6] #1789#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
  [7] Pullback
    @ ./Base.jl:33 [inlined]
  [8] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
  [9] Pullback
    @ ./REPL[74]:1 [inlined]
 [10] (::Zygote.var"#50#51"{typeof(∂(#25))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsu1Y/src/compiler/interface.jl:41
 [11] gradient(::Function, ::AnalyticWeights{Float64, Float64, Vector{Float64}}, ::Vararg{Any, N} where N)
    @ Zygote ~/.julia/packages/Zygote/nsu1Y/src/compiler/interface.jl:76
 [12] top-level scope
    @ REPL[74]:1
It's possible to get around this by using w -> sum(w) instead (after adding an rrule for the constructor, see below), but the original example is gradient(std, rand(3), AnalyticWeights([1.0, 2.0, 3.0])), which accesses the weights under the hood.
function ChainRulesCore.rrule(::Type{StatsBase.AnalyticWeights}, values)
    AnalyticWeights_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ)
    AnalyticWeights_pullback(ȳ::Tangent) = (NoTangent(), ȳ.values)
    AnalyticWeights_pullback(ȳ::AbstractThunk) = AnalyticWeights_pullback(unthunk(ȳ))
    return AnalyticWeights(values), AnalyticWeights_pullback
end
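
To see what the three pullback methods above do in isolation, the rule can be called by hand through ChainRulesCore, without involving Zygote. This is only a minimal sketch: the cotangent values are made up for illustration, and it says nothing about whether the w.sum failure itself is resolved. It only shows that each cotangent type (plain array, structural Tangent, thunk) ends up forwarding the gradient for the values argument.

```julia
using ChainRulesCore
using StatsBase

# Same rule as in the report, repeated so this snippet runs on its own.
function ChainRulesCore.rrule(::Type{StatsBase.AnalyticWeights}, values)
    AnalyticWeights_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ)
    AnalyticWeights_pullback(ȳ::Tangent) = (NoTangent(), ȳ.values)
    AnalyticWeights_pullback(ȳ::AbstractThunk) = AnalyticWeights_pullback(unthunk(ȳ))
    return AnalyticWeights(values), AnalyticWeights_pullback
end

# Call the rule by hand: forward result plus the pullback closure.
w, back = ChainRulesCore.rrule(AnalyticWeights, [1.0, 2.0, 3.0])

# Illustrative cotangent (hypothetical values, not taken from the report).
ȳ = [0.1, 0.2, 0.3]

# Array cotangent: passed through unchanged.
_, from_array = back(ȳ)

# Structural cotangent: only the `values` field is forwarded.
_, from_tangent = back(Tangent{typeof(w)}(values = ȳ))

# Thunked cotangent: unthunked first, then handled like the array case.
_, from_thunk = back(@thunk(ȳ))

@assert from_array == ȳ
@assert from_tangent == ȳ
@assert from_thunk == ȳ
```

Note that a structural cotangent carrying only a sum field would not be covered by these methods in any useful way, since the rule reads ȳ.values alone.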