Support non-standard scalars in test_scalar #61

Open
wants to merge 11 commits into
base: main
Changes from 7 commits
9 changes: 8 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.5.2"
version = "0.5.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -14,4 +14,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ChainRulesCore = "0.9.1"
Compat = "3"
FiniteDifferences = "0.10"
Quaternions = "0.4"
julia = "1"

[extras]
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"

[targets]
test = ["Quaternions"]
58 changes: 26 additions & 32 deletions src/testers.jl
@@ -112,61 +112,55 @@ at input point `z` to confirm that there are correct `frule` and `rrule`s provid

`fkwargs` are passed to `f` as keyword arguments.
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`.

To use this tester for a scalar type `MyNumber <: Number`,
`FiniteDifferences.to_vec(::MyNumber)` must be implemented.
"""
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
Member

Can we simplify this code by defining a separate method for real inputs:

Suggested change
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
function test_scalar(f, z::Real; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)

?

Member Author

It would simplify the frule test because we wouldn't need the basis, but if the output is non-real we still need the basis on the output for the rrule test. Adding a separate method would require us to maintain that code in two places.
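
To illustrate the point about non-real outputs, here is a minimal sketch. It assumes a hypothetical stand-in `scalar_to_vec` in place of `FiniteDifferences.to_vec` so that it runs without any packages; the function `f` is also just an example. A real input has a single tangent direction, but its complex output still has two cotangent directions that the `rrule` test must cover.

```julia
# Hypothetical stand-in for FiniteDifferences.to_vec, limited to Real and Complex.
scalar_to_vec(x::Real) = ([x], v -> v[1])
scalar_to_vec(z::Complex) = ([real(z), imag(z)], v -> Complex(v[1], v[2]))

# Example function with a real input and a complex output.
f(x) = exp(im * x)

x = 0.5
Ω = f(x)

# Number of basis tangents needed on the input side and
# basis cotangents needed on the output side.
n_tangents = length(first(scalar_to_vec(x)))
n_cotangents = length(first(scalar_to_vec(Ω)))
```

In this sketch a `z::Real` method could drop the input basis, but it would still have to loop over the output basis, which is the duplication described above.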

_ensure_not_running_on_functor(f, "test_scalar")
# z = x + im * y
# Ω = u(x, y) + im * v(x, y)
Ω = f(z; fkwargs...)

vz, z_from_vec = to_vec(z)
# orthonormal tangent vectors
vz_basis = Diagonal(ones(eltype(vz), length(vz)))
Δzs = [z_from_vec(vz_basis[:, i]) for i in axes(vz_basis, 2)]

# test jacobian using forward mode
Δx = one(z)
@testset "$f at $z, with tangent $Δx" begin
# check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
frule_test(f, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if z isa Complex
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
@testset "$f at $z, with tangent $Δz" for (i, Δz) in enumerate(Δzs)
frule_test(f, (z, Δz); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if !isa(Δz, Real) && i == 1
Member

Shouldn't this be:

Suggested change
if !isa(Δz, Real) && i == 1
if !isa(Δz, Real) && length(Δzs) == 1

Member Author

In this case, no. i == 1 when the given tangent vector is purely real, even if its type isn't Real. This test checks that using an actual Real tangent vector gives the same result.
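
A minimal sketch of what this check does, assuming a hand-written `pushforward` for `f(z) = z^2` as a stand-in for the `frule` call (it is not the ChainRulesCore API) and a hard-coded complex basis in place of the one built from `to_vec(z)`:

```julia
# Stand-in pushforward for f(z) = z^2; hypothetical, not ChainRulesCore.frule.
pushforward(z, Δz) = 2 * z * Δz

z = 0.3 + 0.7im
# Basis tangents for a complex input, written out by hand for this sketch.
Δzs = [1.0 + 0.0im, 0.0 + 1.0im]

for (i, Δz) in enumerate(Δzs)
    if !isa(Δz, Real) && i == 1
        # The first basis tangent is purely real but has type Complex, so
        # compare it against the same tangent converted to an actual Real.
        @assert pushforward(z, real(Δz)) ≈ pushforward(z, Δz)
    end
end
```

With `length(Δzs) == 1` as the condition, the comparison would never run for complex `z` in this sketch, because the basis has two elements.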

# check that same tangent is produced for tangent real(one(z)) and one(z)
@test isapprox(
frule((Zero(), real(Δx)), f, z; fkwargs...)[2],
frule((Zero(), Δx), f, z; fkwargs...)[2],
frule((Zero(), real(Δz)), f, z; fkwargs...)[2],
frule((Zero(), Δz), f, z; fkwargs...)[2],
rtol=rtol,
atol=atol,
kwargs...,
)
end
end
if z isa Complex
Δy = one(z) * im
@testset "$f at $z, with tangent $Δy" begin
# check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
frule_test(f, (z, Δy); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
end
end

vΩ, Ω_from_vec = to_vec(Ω)
# orthonormal cotangent vectors
vΩ_basis = Diagonal(ones(eltype(vΩ), length(vΩ)))
ΔΩs = [Ω_from_vec(vΩ_basis[:, i]) for i in axes(vΩ_basis, 2)]
Member

We should move this out into a helper function basis_vectors, since the same construction is used for both the tangents and the cotangents.
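
One possible shape for such a helper, as a sketch only. The name `basis_vectors` comes from the comment above; the two-argument signature and the `complex_to_vec` stand-in for `FiniteDifferences.to_vec` are assumptions made so the example runs with only the standard library.

```julia
using LinearAlgebra

# Hypothetical helper: given the vector form of a scalar and the function that
# maps a vector back to that scalar type, return one scalar per basis direction.
function basis_vectors(v::AbstractVector, from_vec)
    basis = Diagonal(ones(eltype(v), length(v)))
    return [from_vec(basis[:, i]) for i in axes(basis, 2)]
end

# Stand-in for FiniteDifferences.to_vec on a complex number (assumption).
complex_to_vec(z::Complex) = ([real(z), imag(z)], v -> Complex(v[1], v[2]))

vz, z_from_vec = complex_to_vec(1.5 + 2.0im)
Δzs = basis_vectors(vz, z_from_vec)
```

In `test_scalar` this would let both `Δzs` and `ΔΩs` be built with one call each, for example `basis_vectors(to_vec(z)...)`, assuming `to_vec` keeps returning a `(vector, from_vec)` pair.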


Δx = Δzs[1]
# test jacobian transpose using reverse mode
Δu = one(Ω)
@testset "$f at $z, with cotangent $Δu" begin
# check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
rrule_test(f, Δu, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if Ω isa Complex
# check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im
@testset "$f at $z, with cotangent $ΔΩ" for (i, ΔΩ) in enumerate(ΔΩs)
rrule_test(f, ΔΩ, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if !isa(ΔΩ, Real) && i == 1
# check that same cotangent is produced for cotangent real(one(Ω)) and one(Ω)
back = rrule(f, z)[2]
@test isapprox(
extern(back(real(Δu))[2]),
extern(back(Δu)[2]),
extern(back(real(ΔΩ))[2]),
extern(back(ΔΩ)[2]),
rtol=rtol,
atol=atol,
kwargs...,
)
end
end
if Ω isa Complex
Δv = one(Ω) * im
@testset "$f at $z, with cotangent $Δv" begin
# check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
rrule_test(f, Δv, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
end
end
end

"""
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -1,6 +1,8 @@
using ChainRulesCore
using ChainRulesTestUtils
using FiniteDifferences
using LinearAlgebra
using Quaternions
using Random
using Test

27 changes: 27 additions & 0 deletions test/testers.jl
@@ -40,6 +40,8 @@ function ChainRulesCore.frule((_, Δiter), ::typeof(iterfun), iter)
return s, ∂s
end

quatfun(q::Quaternion) = Quaternion(q.v3, 2 * q.v1, 3 * q.s, 4 * q.v2)

@testset "testers.jl" begin
@testset "test_scalar" begin
double(x) = 2x
@@ -263,4 +265,29 @@ end
frule_test(iterfun, (x, ẋ))
rrule_test(iterfun, randn(), (x, x̄))
end

@testset "test quaternion non-standard scalar" begin
function FiniteDifferences.to_vec(q::Quaternion)
function Quaternion_from_vec(q_vec)
return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4])
end
return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec
end
Comment on lines +374 to +379
Member

We should move this to be defined in the package itself.

Member Author

Do you mean define this in Quaternions or FiniteDifferences?

Member Author

Or ChainRulesTestUtils?

Member

FiniteDifferences

Member Author

I don't know if it makes sense to make Quaternions an optional dependency for FiniteDifferences. Since I am only defining this for the purpose of testing, I'm comfortable with being type-piratical but just in the test suite where it can't pollute the methods table for other users. Thoughts?

Member Author

Hey @oxinabox what do you think?


function ChainRulesCore.frule((_, Δq), ::typeof(quatfun), q)
∂q = Quaternion(Δq)
return quatfun(q), Quaternion(∂q.v3, 2 * ∂q.v1, 3 * ∂q.s, 4 * ∂q.v2)
end

function ChainRulesCore.rrule(::typeof(quatfun), q)
function quatfun_pullback(ΔΩ)
∂Ω = Quaternion(ΔΩ)
return (NO_FIELDS, Quaternion(3 * ∂Ω.v2, 2 * ∂Ω.v1, 4 * ∂Ω.v3, ∂Ω.s))
end
return quatfun(q), quatfun_pullback
end

q = quatrand()
Member

Should we also define rand_tangent(::Quaternion) in this package?
@willtebbutt do you have plans for further advancing rand_tangent?

Member

I do not currently -- not sure that there's much to do beyond integrating it in with ChainRulesTestUtils in some way or another and continuing to add new methods where necessary.

test_scalar(quatfun, q)
end
end