Closed
Description
I have run into an error when trying to compute the gradient or jacobian using Zygote for a function that contains a cross product. I'm relatively new to using AD, so I would like to know if this is user error or something that needs to be fixed/added to the ChainRules package. The MWE and stacktrace are below:
using LinearAlgebra, Zygote
g(a, b) = hypot(cross(a, b)...)
h(a, b) = dot(cross(a, b), cross(a, b))
jacobian(g, rand(3), rand(3))
The stacktrace:
ERROR: MethodError: no method matching cross(::Vector{Float64}, ::Tangent{Any, Tuple{Float64, Float64, Float64}})
Closest candidates are:
cross(::AbstractVector, ::AbstractVector) at ~/build/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/generic.jl:310
cross(::Any, ::AbstractThunk) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:89
cross(::AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:88
...
Stacktrace:
[1] (::ChainRules.var"#1949#1952"{Tangent{Any, Tuple{Float64, Float64, Float64}}, Vector{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}})()
@ ChainRules ~/.julia/packages/ChainRules/ajkp7/src/rulesets/LinearAlgebra/dense.jl:109
[2] unthunk
@ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
[3] wrap_chainrules_output
@ ~/.julia/packages/Zygote/tFaxC/src/compiler/chainrules.jl:105 [inlined]
[4] map
@ ./tuple.jl:223 [inlined]
[5] wrap_chainrules_output
@ ~/.julia/packages/Zygote/tFaxC/src/compiler/chainrules.jl:106 [inlined]
[6] (::Zygote.ZBack{ChainRules.var"#cross_pullback#1951"{Vector{Float64}, Vector{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}})(dy::Tuple{Float64, Float64, Float64})
@ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/chainrules.jl:206
[7] Pullback
@ ./REPL[76]:1 [inlined]
[8] (::typeof(∂(g)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface2.jl:0
[9] #208
@ ~/.julia/packages/Zygote/tFaxC/src/lib/lib.jl:206 [inlined]
[10] (::Zygote.var"#2066#back#210"{Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing}}, typeof(∂(g))}})(Δ::Float64)
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[11] Pullback
@ ./operators.jl:1035 [inlined]
[12] (::typeof(∂(#_#95)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface2.jl:0
[13] (::Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof(∂(#_#95))})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/tFaxC/src/lib/lib.jl:206
[14] #2066#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[15] Pullback
@ ./operators.jl:1033 [inlined]
[16] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(g)}(Zygote._jvec, g))))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface2.jl:0
[17] (::Zygote.var"#60#61"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(g)}(Zygote._jvec, g)))})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface.jl:45
[18] withjacobian(::Function, ::Vector{Float64}, ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/tFaxC/src/lib/grad.jl:150
[19] jacobian(::Function, ::Vector{Float64}, ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/tFaxC/src/lib/grad.jl:128
[20] top-level scope
@ REPL[77]:1
A quick fix is to extend the cross product so that it accepts a Tangent:
using ChainRulesCore: Tangent  # Tangent is defined in ChainRulesCore
LinearAlgebra.cross(a::Tangent, b::AbstractVector) = -LinearAlgebra.cross(b, a)
LinearAlgebra.cross(a::AbstractVector, b::Tangent) = [a[2]*b[3]-a[3]*b[2], a[3]*b[1]-a[1]*b[3], a[1]*b[2]-a[2]*b[1]]
Is this fix (a) needed, or am I doing something incorrectly, and (b) if a fix is needed, should it be in the ChainRules package? Does the Tangent need to be projected onto the subspace of the AbstractVector?
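For comparison, here is a sketch of a possible workaround that avoids the splat entirely. It assumes that splatting the result of cross into hypot is what makes the cotangent arrive as a tuple (the stacktrace shows dy::Tuple{Float64, Float64, Float64} reaching cross_pullback), and that norm is an acceptable replacement for hypot(...) here. The name g_norm is only illustrative.

```julia
using LinearAlgebra, Zygote

# Same value as g, but the cross product stays a Vector instead of being
# splatted into a tuple, so cross_pullback should receive an array cotangent.
g_norm(a, b) = norm(cross(a, b))

a, b = rand(3), rand(3)

# One Jacobian per argument; g_norm is scalar-valued, so each has a single row.
Ja, Jb = jacobian(g_norm, a, b)
```

If this runs without the MethodError, that would suggest the problem is specific to the Tangent produced by the splat rather than to the cross rule itself.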