
Error differentiating composed cross product with Zygote #689

Closed
@benjaminfaber

Description


I have run into an error when trying to compute the gradient or Jacobian with Zygote of a function that contains a cross product. I'm relatively new to AD, so I would like to know whether this is user error or something that needs to be fixed or added in the ChainRules package. The MWE and stack trace are below:

using LinearAlgebra, Zygote

g(a, b) = hypot(cross(a, b)...)
h(a, b) = (c = cross(a, b); dot(c, c))  # dot takes two arguments; this is |a × b|^2

jacobian(g, rand(3), rand(3))

The stacktrace:

ERROR: MethodError: no method matching cross(::Vector{Float64}, ::Tangent{Any, Tuple{Float64, Float64, Float64}})
Closest candidates are:
  cross(::AbstractVector, ::AbstractVector) at ~/build/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/generic.jl:310
  cross(::Any, ::AbstractThunk) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:89
  cross(::AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:88
  ...
Stacktrace:
  [1] (::ChainRules.var"#1949#1952"{Tangent{Any, Tuple{Float64, Float64, Float64}}, Vector{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}})()
    @ ChainRules ~/.julia/packages/ChainRules/ajkp7/src/rulesets/LinearAlgebra/dense.jl:109
  [2] unthunk
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
  [3] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/tFaxC/src/compiler/chainrules.jl:105 [inlined]
  [4] map
    @ ./tuple.jl:223 [inlined]
  [5] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/tFaxC/src/compiler/chainrules.jl:106 [inlined]
  [6] (::Zygote.ZBack{ChainRules.var"#cross_pullback#1951"{Vector{Float64}, Vector{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}})(dy::Tuple{Float64, Float64, Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/chainrules.jl:206
  [7] Pullback
    @ ./REPL[76]:1 [inlined]
  [8] (::typeof((g)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface2.jl:0
  [9] #208
    @ ~/.julia/packages/Zygote/tFaxC/src/lib/lib.jl:206 [inlined]
 [10] (::Zygote.var"#2066#back#210"{Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing}}, typeof((g))}})(Δ::Float64)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [11] Pullback
    @ ./operators.jl:1035 [inlined]
 [12] (::typeof((#_#95)))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface2.jl:0
 [13] (::Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof((#_#95))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/lib/lib.jl:206
 [14] #2066#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [15] Pullback
    @ ./operators.jl:1033 [inlined]
 [16] (::typeof((ComposedFunction{typeof(Zygote._jvec), typeof(g)}(Zygote._jvec, g))))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#60#61"{typeof((ComposedFunction{typeof(Zygote._jvec), typeof(g)}(Zygote._jvec, g)))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/compiler/interface.jl:45
 [18] withjacobian(::Function, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/lib/grad.jl:150
 [19] jacobian(::Function, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/tFaxC/src/lib/grad.jl:128
 [20] top-level scope
    @ REPL[77]:1
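As a sanity check for any fix, the gradients of g and h can be worked out by hand. Frame [6] of the trace shows the cross pullback receiving dy::Tuple{Float64, Float64, Float64}, which appears to come from splatting the cross product into hypot. The sketch below therefore writes g with norm instead, assuming norm(c) is an acceptable substitute for hypot(c...) (both compute the Euclidean length). It uses only LinearAlgebra and compares the analytic gradients with central finite differences; the names g_ref, h_ref, grad_g, grad_h and fd_grad are hypothetical.

```julia
using LinearAlgebra

# Hypothetical reference versions of the MWE functions. norm(c) replaces
# hypot(c...) (both give the Euclidean length) so that nothing is splatted.
g_ref(a, b) = norm(cross(a, b))
h_ref(a, b) = (c = cross(a, b); dot(c, c))

# Analytic gradients with c = a × b, from dc = da × b + a × db and the
# cyclic scalar triple product identity u ⋅ (v × w) = v ⋅ (w × u):
#   ∇ₐ|c|  = (b × c) / |c|,   ∇_b |c|  = (c × a) / |c|
#   ∇ₐ|c|² = 2 (b × c),       ∇_b |c|² = 2 (c × a)
function grad_g(a, b)
    c = cross(a, b)
    return cross(b, c) / norm(c), cross(c, a) / norm(c)
end

function grad_h(a, b)
    c = cross(a, b)
    return 2 * cross(b, c), 2 * cross(c, a)
end

# Central finite-difference gradient of f with respect to its k-th argument.
function fd_grad(f, args, k; ε = 1e-6)
    x = args[k]
    grad = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = ε
        plus = ntuple(j -> j == k ? x + e : args[j], length(args))
        minus = ntuple(j -> j == k ? x - e : args[j], length(args))
        grad[i] = (f(plus...) - f(minus...)) / (2ε)
    end
    return grad
end

a = [1.0, 2.0, 3.0]
b = [0.5, -1.0, 2.0]
@show isapprox(grad_g(a, b)[1], fd_grad(g_ref, (a, b), 1); atol = 1e-6)
@show isapprox(grad_h(a, b)[2], fd_grad(h_ref, (a, b), 2); atol = 1e-5)
```

Any AD result for g or h, with or without the cross extension below, should agree with grad_g and grad_h up to rounding.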

A quick fix is to extend cross so that it accepts a Tangent:

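using ChainRulesCore: Tangent  # needs ChainRulesCore in the active environment
LinearAlgebra.cross(a::Tangent, b::AbstractVector) = -LinearAlgebra.cross(b, a)  # a × b = -(b × a)
LinearAlgebra.cross(a::AbstractVector, b::Tangent) = [a[2]*b[3]-a[3]*b[2], a[3]*b[1]-a[1]*b[3], a[1]*b[2]-a[2]*b[1]]  # a × b, component-wise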
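For reference, the reverse-mode rule for c = a × b follows from the scalar triple product identity: for a cotangent Δ on c, the pullback is a_bar = b × Δ and b_bar = Δ × a. The quick fix above supplies exactly these two products when Δ is a Tangent. The sketch below is a hypothetical helper, cross_pullback, not the ChainRules rule itself. It collects Δ into a Vector before calling cross, uses a plain Tuple as a stand-in for the Tangent{Any, Tuple{...}} from the error, and checks the rule against the linearization of a × b.

```julia
using LinearAlgebra

# Hypothetical helper, not the ChainRules rule itself. For c = a × b and a
# cotangent Δ on c, the scalar triple product identity gives
#   ⟨Δ, da × b⟩ = ⟨b × Δ, da⟩   and   ⟨Δ, a × db⟩ = ⟨Δ × a, db⟩,
# so the pullback is a_bar = b × Δ and b_bar = Δ × a.
function cross_pullback(a::AbstractVector, b::AbstractVector, Δ)
    Δv = collect(Δ)  # Tuple (or Vector) cotangent -> Vector, so cross dispatches
    return cross(b, Δv), cross(Δv, a)
end

a = [1.0, 2.0, 3.0]
b = [0.5, -1.0, 2.0]
Δ = (0.3, -0.7, 1.1)  # plain Tuple standing in for Tangent{Any, Tuple{...}}

a_bar, b_bar = cross_pullback(a, b, Δ)

# a × b is bilinear, so its derivative along (δa, δb) is δa × b + a × δb;
# the pullback must satisfy ⟨Δ, δc⟩ = ⟨a_bar, δa⟩ + ⟨b_bar, δb⟩.
δa = [0.01, 0.02, -0.03]
δb = [-0.02, 0.01, 0.04]
δc = cross(δa, b) + cross(a, δb)
@show dot(collect(Δ), δc) ≈ dot(a_bar, δa) + dot(b_bar, δb)
```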

So: (a) is this fix needed, or am I doing something incorrectly? (b) If a fix is needed, should it go in the ChainRules package? And does the Tangent need to be projected onto the subspace of the AbstractVector?
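On the projection question, one reading of "projecting the Tangent onto the AbstractVector" is to copy its three components back into a vector shaped like the primal before calling cross. The helper project_like below is hypothetical and is not the ChainRulesCore ProjectTo API; it only checks the length and converts to a floating-point element type.

```julia
using LinearAlgebra

# Hypothetical helper (not the ChainRulesCore ProjectTo API): copy a
# tuple-like cotangent into a vector shaped like the primal x, with a
# floating-point element type, after checking that the lengths agree.
function project_like(x::AbstractVector{T}, Δ) where {T<:Real}
    length(Δ) == length(x) || throw(DimensionMismatch(
        "cotangent has length $(length(Δ)), primal has length $(length(x))"))
    out = similar(x, float(T))
    for (i, d) in zip(eachindex(out), Δ)
        out[i] = d
    end
    return out
end

x = [1.0, 2.0, 3.0]
Δ = (0.5, 0.0, -1.0)     # tuple-like cotangent, as in the error message
Δx = project_like(x, Δ)  # Vector{Float64}
cross(x, Δx)             # [-2.0, 2.5, -1.0]
```

Once the cotangent is a plain vector, the existing cross(::AbstractVector, ::AbstractVector) method applies, so in this sketch no new cross methods on Tangent are needed.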
