Skip to content

Commit 72dad7f

Browse files
authored
rename differential types (#153)
* replace DoesNotExist * replace Composite * replace Zero * bump version and compat
1 parent faf4793 commit 72dad7f

File tree

10 files changed

+93
-93
lines changed

10 files changed

+93
-93
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.11"
3+
version = "0.6.12"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.9.39"
14+
ChainRulesCore = "0.9.44"
1515
Compat = "3"
1616
FiniteDifferences = "0.12"
1717
julia = "1"

docs/src/index.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using ChainRulesCore
2626
2727
function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
2828
y = two2three(x1, x2)
29-
∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δx1, 3.0*Δx2)
29+
∂y = Tangent{Tuple{Float64, Float64, Float64}}(ZeroTangent(), 2.0*Δx1, 3.0*Δx2)
3030
return y, ∂y
3131
end
3232
# output
@@ -117,11 +117,11 @@ Test.DefaultTestSet("test_scalar: relu at -0.5", Any[Test.DefaultTestSet("with t
117117
[`test_frule`](@ref) and [`test_rrule`](@ref) allow you to specify the tangents used for testing.
118118
This is done by passing in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in the place of the primal inputs.
119119
If this is not done the tangent will be automatically generated via `FiniteDifferences.rand_tangent`.
120-
A special case of this is that if you specify it as `x ⊢ DoesNotExist()` then finite differencing will not be used on that input.
120+
A special case of this is that if you specify it as `x ⊢ NoTangent()` then finite differencing will not be used on that input.
121121
Similarly, by setting the `output_tangent` keyword argument, you can specify the tangent for the primal output.
122122

123123
This can be useful when the default provided `FiniteDifferences.rand_tangent` doesn't produce the desired tangent for your type.
124-
For example the default tangent for an `Int` is `DoesNotExist()`.
124+
For example the default tangent for an `Int` is `NoTangent()`.
125125
Which is correct e.g. when the `Int` represents a discrete integer like in indexing.
126126
But if you are testing something where the `Int` is actually a special case of a real number, then you would want to specify the tangent as a `Float64`.
127127

@@ -134,7 +134,7 @@ Inserting inappropriate zeros can thus hide errors.
134134

135135
If a package is using a custom finite differencing method of testing the `frule`s and `rrule`s, `check_equal` function provides a convenient way of comparing [various types](https://www.juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html#Design-Notes:-The-many-to-many-relationship-between-differential-types-and-primal-types.) of differentials.
136136

137-
It is effectively `(a, b) -> @test isapprox(a, b)`, but it preprocesses `thunk`s and `ChainRules` differential types `Zero()`, `DoesNotExist()`, and `Composite`, such that the error messages are helpful.
137+
It is effectively `(a, b) -> @test isapprox(a, b)`, but it preprocesses `thunk`s and `ChainRules` differential types `ZeroTangent()`, `NoTangent()`, and `Tangent`, such that the error messages are helpful.
138138

139139
For example,
140140
```julia

src/check_result.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, Abs
2525
end
2626
end
2727

28-
check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...)
29-
check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...)
30-
check_equal(x::Zero, y::Zero; kwargs...) = @test true
28+
check_equal(::ZeroTangent, x; kwargs...) = check_equal(zero(x), x; kwargs...)
29+
check_equal(x, ::ZeroTangent; kwargs...) = check_equal(x, zero(x); kwargs...)
30+
check_equal(x::ZeroTangent, y::ZeroTangent; kwargs...) = @test true
3131

3232
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
33-
check_equal(x::DoesNotExist, y::Nothing; kwargs...) = @test true
34-
check_equal(x::Nothing, y::DoesNotExist; kwargs...) = @test true
33+
check_equal(x::NoTangent, y::Nothing; kwargs...) = @test true
34+
check_equal(x::Nothing, y::NoTangent; kwargs...) = @test true
3535

3636
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
3737
# not yet been implemented
@@ -75,7 +75,7 @@ function check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...)
7575
end
7676
end
7777

78-
function check_equal(actual::Composite{P}, expected::Composite{P}; kwargs...) where P
78+
function check_equal(actual::Tangent{P}, expected::Tangent{P}; kwargs...) where P
7979
if _can_pass_early(actual, expected)
8080
@test true
8181
else
@@ -87,7 +87,7 @@ function check_equal(actual::Composite{P}, expected::Composite{P}; kwargs...) wh
8787
end
8888

8989
function check_equal(
90-
::Composite{ActualPrimal}, expected::Composite{ExpectedPrimal}; kwargs...
90+
::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}; kwargs...
9191
) where {ActualPrimal, ExpectedPrimal}
9292
# this will certainly fail as we have another dispatch for that, but this will give as
9393
# good error message
@@ -96,26 +96,26 @@ end
9696

9797

9898
# Some structual differential and a natural differential
99-
function check_equal(actual::Composite{P, T}, expected; kwargs...) where {T, P}
99+
function check_equal(actual::Tangent{P, T}, expected; kwargs...) where {T, P}
100100
if _can_pass_early(actual, expected)
101101
@test true
102102
else
103103
@assert (T <: NamedTuple) # it should be a structual differential if we hit this
104104

105-
# We are only checking the properties that are in the Composite
105+
# We are only checking the properties that are in the Tangent
106106
# the natural differential is allowed to have other properties that we ignore
107107
@testset "$P.$ii" for ii in propertynames(actual)
108108
check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...)
109109
end
110110
end
111111
end
112-
check_equal(x, y::Composite; kwargs...) = check_equal(y, x; kwargs...)
112+
check_equal(x, y::Tangent; kwargs...) = check_equal(y, x; kwargs...)
113113

114-
# This catches comparisons of Composites and Tuples/NamedTuple
114+
# This catches comparisons of Tangents and Tuples/NamedTuple
115115
# and gives an error message complaining about that
116116
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
117-
check_equal(::C, ::T; kwargs...) where {C<:Composite,T<:LegacyZygoteCompTypes} = @test C === T
118-
check_equal(::T, ::C; kwargs...) where {C<:Composite,T<:LegacyZygoteCompTypes} = @test T === C
117+
check_equal(::C, ::T; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test C === T
118+
check_equal(::T, ::C; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test T === C
119119

120120
# Generic fallback, probably a tuple or something
121121
function check_equal(actual::A, expected::E; kwargs...) where {A, E}

src/deprecated.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,30 @@
33

44
Base.isapprox(a, b::Union{AbstractZero, AbstractThunk}; kwargs...) = isapprox(b, a; kwargs...)
55
Base.isapprox(d_ad::AbstractThunk, d_fd; kwargs...) = isapprox(extern(d_ad), d_fd; kwargs...)
6-
Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a `DoesNotExist`")
7-
# Call `all` to handle the case where `Zero` is standing in for a non-scalar zero
8-
Base.isapprox(d_ad::Zero, d_fd; kwargs...) = all(isapprox.(extern(d_ad), d_fd; kwargs...))
6+
Base.isapprox(d_ad::NoTangent, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a `NoTangent`")
7+
# Call `all` to handle the case where `ZeroTangent` is standing in for a non-scalar zero
8+
Base.isapprox(d_ad::ZeroTangent, d_fd; kwargs...) = all(isapprox.(extern(d_ad), d_fd; kwargs...))
99

1010
isapprox_vec(a, b; kwargs...) = isapprox(first(to_vec(a)), first(to_vec(b)); kwargs...)
11-
Base.isapprox(a, b::Composite; kwargs...) = isapprox(b, a; kwargs...)
12-
function Base.isapprox(d_ad::Composite{<:Tuple}, d_fd::Tuple; kwargs...)
11+
Base.isapprox(a, b::Tangent; kwargs...) = isapprox(b, a; kwargs...)
12+
function Base.isapprox(d_ad::Tangent{<:Tuple}, d_fd::Tuple; kwargs...)
1313
return isapprox_vec(d_ad, d_fd; kwargs...)
1414
end
1515
function Base.isapprox(
16-
d_ad::Composite{P, <:Tuple}, d_fd::Composite{P, <:Tuple}; kwargs...
16+
d_ad::Tangent{P, <:Tuple}, d_fd::Tangent{P, <:Tuple}; kwargs...
1717
) where {P <: Tuple}
1818
return isapprox_vec(d_ad, d_fd; kwargs...)
1919
end
2020

2121
function Base.isapprox(
22-
d_ad::Composite{P, <:NamedTuple{T}}, d_fd::Composite{P, <:NamedTuple{T}}; kwargs...,
22+
d_ad::Tangent{P, <:NamedTuple{T}}, d_fd::Tangent{P, <:NamedTuple{T}}; kwargs...,
2323
) where {P, T}
2424
return isapprox_vec(d_ad, d_fd; kwargs...)
2525
end
2626

2727

2828
# Must be for same primal
29-
Base.isapprox(d_ad::Composite{P}, d_fd::Composite{Q}; kwargs...) where {P, Q} = false
29+
Base.isapprox(d_ad::Tangent{P}, d_fd::Tangent{Q}; kwargs...) where {P, Q} = false
3030

3131

3232
# From when primal and tangent was passed as a tuple

src/finite_difference_calls.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ end
9292

9393
# TODO: remove after https://github.com/JuliaDiff/FiniteDifferences.jl/issues/97
9494
# For functions which return a tuple, FD returns a tuple to represent the differential. Tuple
95-
# is not a natural differential, because it doesn't overload +, so make it a Composite.
96-
_maybe_fix_to_composite(::P, x::Tuple) where {P} = Composite{P}(x...)
97-
_maybe_fix_to_composite(::P, x::NamedTuple) where {P} = Composite{P}(;x...)
95+
# is not a natural differential, because it doesn't overload +, so make it a Tangent.
96+
_maybe_fix_to_composite(::P, x::Tuple) where {P} = Tangent{P}(x...)
97+
_maybe_fix_to_composite(::P, x::NamedTuple) where {P} = Tangent{P}(;x...)
9898
_maybe_fix_to_composite(::Any, x) = x

src/iterator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function Base.hash(iter::TestIterator{<:Any,IT,IS}) where {IT,IS}
4444
end
4545

4646
# To make it a valid differential: needs at very least `zero` and `+`
47-
Base.zero(::Type{<:TestIterator}) = Zero()
47+
Base.zero(::Type{<:TestIterator}) = ZeroTangent()
4848
function Base.:+(iter1::TestIterator{T,IS,IE}, iter2::TestIterator{T,IS,IE}) where {T,IS,IE}
4949
return TestIterator{T,IS,IE}(map(+, iter1.data, iter2.data))
5050
end

src/testers.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
3030
test_frule(f, z Δx; rule_test_kwargs...)
3131
if z isa Complex
3232
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
33-
_, real_tangent = frule((Zero(), real(Δx)), f, z; fkwargs...)
34-
_, embedded_tangent = frule((Zero(), Δx), f, z; fkwargs...)
33+
_, real_tangent = frule((ZeroTangent(), real(Δx)), f, z; fkwargs...)
34+
_, embedded_tangent = frule((ZeroTangent(), Δx), f, z; fkwargs...)
3535
check_equal(real_tangent, embedded_tangent; isapprox_kwargs...)
3636
end
3737
end
@@ -75,7 +75,7 @@ end
7575
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
7676
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
7777
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
78-
Non-differentiable arguments, such as indices, should have `ẋ` set as `DoesNotExist()`.
78+
Non-differentiable arguments, such as indices, should have `ẋ` set as `NoTangent()`.
7979
8080
# Keyword Arguments
8181
- `output_tangent` tangent to test accumulation of derivatives against
@@ -114,11 +114,11 @@ function test_frule(
114114
check_equal(Ω_ad, Ω; isapprox_kwargs...)
115115

116116
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
117-
ẋs_is_ignored = isa.(ẋs, Union{Nothing, DoesNotExist})
117+
ẋs_is_ignored = isa.(ẋs, Union{Nothing, NoTangent})
118118
if any(ẋs .== nothing)
119119
Base.depwarn(
120120
"test_frule(f, k ⊢ nothing) is deprecated, use " *
121-
"test_frule(f, k ⊢ DoesNotExist()) instead for non-differentiable ks",
121+
"test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
122122
:test_frule
123123
)
124124
end
@@ -142,7 +142,7 @@ end
142142
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
143143
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
144144
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
145-
Non-differentiable arguments, such as indices, should have `x̄` set as `DoesNotExist()`.
145+
Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.
146146
147147
# Keyword Arguments
148148
- `output_tangent` the seed to propagate backward for testing (techncally a cotangent).
@@ -191,24 +191,24 @@ function test_rrule(
191191

192192
# Correctness testing via finite differencing.
193193
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
194-
x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing, DoesNotExist})
194+
x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing, NoTangent})
195195
if any(accumulated_x̄ .== nothing)
196196
Base.depwarn(
197197
"test_rrule(f, k ⊢ nothing) is deprecated, use " *
198-
"test_rrule(f, k ⊢ DoesNotExist()) instead for non-differentiable ks",
198+
"test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
199199
:test_rrule
200200
)
201201
end
202202

203203
x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne)
204204
for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
205-
if accumulated_x̄ isa Union{Nothing, DoesNotExist} # then we marked this argument as not differentiable # TODO remove once #113
205+
if accumulated_x̄ isa Union{Nothing, NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
206206
@assert x̄_fd === nothing # this is how `_make_j′vp_call` works
207-
x̄_ad isa Zero && error(
208-
"The pullback in the rrule for $f function should use DoesNotExist()" *
209-
" rather than Zero() for non-perturbable arguments."
207+
x̄_ad isa ZeroTangent && error(
208+
"The pullback in the rrule for $f function should use NoTangent()" *
209+
" rather than ZeroTangent() for non-perturbable arguments."
210210
)
211-
@test x̄_ad isa DoesNotExist # we said it wasn't differentiable.
211+
@test x̄_ad isa NoTangent # we said it wasn't differentiable.
212212
else
213213
x̄_ad isa AbstractThunk && check_inferred && _test_inferred(unthunk, x̄_ad)
214214

test/check_result.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11

2-
struct FakeNaturalDiffWithIsApprox # For testing overloading isapprox(::Composite) works:
2+
struct FakeNaturalDiffWithIsApprox # For testing overloading isapprox(::Tangent) works:
33
x
44
end
5-
function Base.isapprox(c::Composite, d::FakeNaturalDiffWithIsApprox; kwargs...)
5+
function Base.isapprox(c::Tangent, d::FakeNaturalDiffWithIsApprox; kwargs...)
66
return isapprox(c.x, d.x, kwargs...)
77
end
8-
function Base.isapprox(d::FakeNaturalDiffWithIsApprox, c::Composite; kwargs...)
8+
function Base.isapprox(d::FakeNaturalDiffWithIsApprox, c::Tangent; kwargs...)
99
return isapprox(c.x, d.x, kwargs...)
1010
end
1111

@@ -15,7 +15,7 @@ end
1515
check = ChainRulesTestUtils._check_add!!_behaviour
1616

1717
check(10.0, 2.0)
18-
check(11.0, Zero())
18+
check(11.0, ZeroTangent())
1919
check([10.0, 20.0], @thunk([2.0, 0.0]))
2020

2121
check(12.0, InplaceableThunk(@thunk(2.0), X̄ -> error("Should not have in-placed")))
@@ -39,7 +39,7 @@ end
3939
check_equal(1.0, 1.0+1e-10) # isapprox _behaviour
4040
check_equal((1.5, 2.5, 3.5), (1.5, 2.5, 3.5 + 1e-10))
4141

42-
check_equal(Zero(), 0.0)
42+
check_equal(ZeroTangent(), 0.0)
4343

4444
check_equal([1.0, 2.0], [1.0, 2.0])
4545
check_equal([[1.0], [2.0]], [[1.0], [2.0]])
@@ -51,46 +51,46 @@ end
5151
check_equal(@not_implemented("a"), @not_implemented("a"))
5252

5353
check_equal(
54-
Composite{Tuple{Float64, Float64}}(1.0, 2.0),
55-
Composite{Tuple{Float64, Float64}}(1.0, 2.0)
54+
Tangent{Tuple{Float64, Float64}}(1.0, 2.0),
55+
Tangent{Tuple{Float64, Float64}}(1.0, 2.0)
5656
)
5757

5858
diag_eg = Diagonal(randn(5))
5959
check_equal( # Structual == Structural
60-
Composite{typeof(diag_eg)}(diag=diag_eg.diag),
61-
Composite{typeof(diag_eg)}(diag=diag_eg.diag)
60+
Tangent{typeof(diag_eg)}(diag=diag_eg.diag),
61+
Tangent{typeof(diag_eg)}(diag=diag_eg.diag)
6262
)
6363
check_equal( # Structural == Natural
64-
Composite{typeof(diag_eg)}(diag=diag_eg.diag),
64+
Tangent{typeof(diag_eg)}(diag=diag_eg.diag),
6565
diag_eg
6666
)
6767

6868
T = (a=1.0, b=2.0)
6969
check_equal(
70-
Composite{typeof(T)}(a=1.0),
71-
Composite{typeof(T)}(a=1.0, b=Zero())
70+
Tangent{typeof(T)}(a=1.0),
71+
Tangent{typeof(T)}(a=1.0, b=ZeroTangent())
7272
)
7373
check_equal(
74-
Composite{typeof(T)}(a=1.0),
75-
Composite{typeof(T)}(a=1.0+1e-10, b=Zero())
74+
Tangent{typeof(T)}(a=1.0),
75+
Tangent{typeof(T)}(a=1.0+1e-10, b=ZeroTangent())
7676
)
7777

7878
check_equal(
79-
Composite{FakeNaturalDiffWithIsApprox}(; x=1.4),
79+
Tangent{FakeNaturalDiffWithIsApprox}(; x=1.4),
8080
FakeNaturalDiffWithIsApprox(1.4)
8181
)
8282
check_equal(
8383
FakeNaturalDiffWithIsApprox(1.4),
84-
Composite{FakeNaturalDiffWithIsApprox}(; x=1.4)
84+
Tangent{FakeNaturalDiffWithIsApprox}(; x=1.4)
8585
)
8686
end
8787
@testset "negative case" begin
8888
@test fails(()->check_equal(1.0, 2.0))
8989
@test fails(()->check_equal(1.0 + im, 1.0 - im))
9090
@test fails(()->check_equal((1.5, 2.5, 3.5), (1.5, 2.5, 4.5)))
9191

92-
@test fails(()->check_equal(Zero(), 20.0))
93-
@test fails(()->check_equal(10.0, Zero()))
92+
@test fails(()->check_equal(ZeroTangent(), 20.0))
93+
@test fails(()->check_equal(10.0, ZeroTangent()))
9494

9595
@test fails(()->check_equal([1.0, 2.0], [1.0, 3.9]))
9696
@test fails(()->check_equal([[1.0], [2.0]], [[1.1], [2.0]]))
@@ -102,12 +102,12 @@ end
102102
@testset "type negative" begin
103103
@test fails() do # these have different primals so should not be equal
104104
check_equal(
105-
Composite{Tuple{Float32, Float32}}(1f0, 2f0),
106-
Composite{Tuple{Float64, Float64}}(1.0, 2.0)
105+
Tangent{Tuple{Float32, Float32}}(1f0, 2f0),
106+
Tangent{Tuple{Float64, Float64}}(1.0, 2.0)
107107
)
108108
end
109109
@test fails() do
110-
check_equal((1.0, 2.0), Composite{Tuple{Float64, Float64}}(1.0, 2.0))
110+
check_equal((1.0, 2.0), Tangent{Tuple{Float64, Float64}}(1.0, 2.0))
111111
end
112112
end
113113

0 commit comments

Comments
 (0)