Skip to content

Commit 313f65f

Browse files
committed
Remove Wirtinger
1 parent 146d031 commit 313f65f

File tree

5 files changed

+15
-77
lines changed

5 files changed

+15
-77
lines changed

docs/src/index.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ The most important `AbstractDifferential`s when getting started are the ones abo
233233
- `One`, `Zero`: There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition.
234234

235235
#### Other `AbstractDifferential`s: don't worry about them right now
236-
- `Wirtinger`: it is complex. The docs need to be better. [Read the links in this issue](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/40).
237236
- `Casted`: it implements broadcasting mechanics. See [#10](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/10)
238237
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.
239238

src/rulesets/Base/base.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
@scalar_rule(zero(x), Zero())
33
@scalar_rule(sign(x), Zero())
44

5-
@scalar_rule(abs2(x), Wirtinger(x', x))
5+
@scalar_rule(abs2(x), 2x)
66
@scalar_rule(log(x), inv(x))
77
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))
88
@scalar_rule(log2(x), inv(x) / log(oftype(x, 2)))
@@ -50,14 +50,10 @@
5050
@scalar_rule(deg2rad(x), π / oftype(x, 180))
5151
@scalar_rule(rad2deg(x), oftype(x, 180) / π)
5252

53-
@scalar_rule(conj(x), Wirtinger(Zero(), One()))
54-
@scalar_rule(adjoint(x), Wirtinger(Zero(), One()))
5553
@scalar_rule(transpose(x), One())
5654

5755
@scalar_rule(abs(x::Real), sign(x))
58-
@scalar_rule(abs(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
5956
@scalar_rule(hypot(x::Real), sign(x))
60-
@scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
6157
@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist()))
6258

6359
@scalar_rule(+(x), One())
@@ -98,11 +94,8 @@
9894
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
9995
@scalar_rule(fma(x, y, z), (y, x, One()))
10096
@scalar_rule(muladd(x, y, z), (y, x, One()))
101-
@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u))
10297
@scalar_rule(angle(x::Real), Zero())
103-
@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2))
10498
@scalar_rule(real(x::Real), One())
105-
@scalar_rule(imag(x::Complex), Wirtinger(-im//2, im//2))
10699
@scalar_rule(imag(x::Real), Zero())
107100

108101
# product rule requires special care for arguments where `mul` is non-commutative

test/rulesets/Base/base.jl

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
end # Trig
6363

6464
@testset "math" begin
65-
for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im)
65+
for x in (-0.1, 6.4)
6666
test_scalar(deg2rad, x)
6767
test_scalar(rad2deg, x)
6868

@@ -73,30 +73,6 @@
7373
test_scalar(exp10, x)
7474

7575
x isa Real && test_scalar(cbrt, x)
76-
if (x isa Real && x >= 0) || x isa Complex
77-
# this check is needed because these have discontinuities between
78-
# `-10 + im*eps()` and `-10 - im*eps()`
79-
should_test_wirtinger = imag(x) != 0 && real(x) < 0
80-
test_scalar(sqrt, x; test_wirtinger=should_test_wirtinger)
81-
test_scalar(log, x; test_wirtinger=should_test_wirtinger)
82-
test_scalar(log2, x; test_wirtinger=should_test_wirtinger)
83-
test_scalar(log10, x; test_wirtinger=should_test_wirtinger)
84-
test_scalar(log1p, x; test_wirtinger=should_test_wirtinger)
85-
end
86-
end
87-
end
88-
89-
@testset "Unary complex functions" begin
90-
for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im)
91-
test_scalar(real, x)
92-
test_scalar(imag, x)
93-
94-
test_scalar(abs, x)
95-
test_scalar(hypot, x)
96-
97-
test_scalar(angle, x)
98-
test_scalar(abs2, x)
99-
test_scalar(conj, x)
10076
end
10177
end
10278

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ using Test
1111

1212
# For testing purposes we use a lot of
1313
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
14-
Wirtinger, wirtinger_primal, wirtinger_conjugate,
1514
Zero, One, DoesNotExist, Thunk, AbstractDifferential
1615

1716
Random.seed!(1) # Set seed that all testsets should reset to.

test/test_util.jl

Lines changed: 13 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ const _fdm = central_fdm(5, 1)
77

88

99
"""
10-
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), test_wirtinger=x isa Complex, kwargs...)
10+
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
1111
1212
Given a function `f` with scalar input an scalar output, perform finite differencing checks,
1313
at input point `x` to confirm that there are correct ChainRules provided.
@@ -16,49 +16,22 @@ at input point `x` to confirm that there are correct ChainRules provided.
1616
- `f`: Function for which the `frule` and `rrule` should be tested.
1717
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
1818
19-
- `test_wirtinger`: test whether the wirtinger derivative is correct, too
20-
21-
All keyword arguments except for `fdm` and `test_wirtinger` are passed to `isapprox`.
19+
All keyword arguments except for `fdm` is passed to `isapprox`.
2220
"""
23-
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa Complex, kwargs...)
21+
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
2422
ensure_not_running_on_functor(f, "test_scalar")
2523

26-
@testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule)
27-
res = rule(f, x)
28-
@test res !== nothing # Check the rule was defined
29-
fx, prop_rule = res
24+
r_res = rrule(f, x)
25+
f_res = frule(f, x, Zero(), 1)
26+
@test r_res !== f_res !== nothing # Check the rule was defined
27+
r_fx, prop_rule = r_res
28+
f_fx, f_∂x = f_res
29+
@testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in ((rrule, r_fx, prop_rule(1)), (frule, f_fx, f_∂x))
3030
@test fx == f(x) # Check we still get the normal value, right
3131

3232
if rule == rrule
33-
∂self, ∂x = prop_rule(1)
33+
∂self, ∂x = ∂x
3434
@test ∂self === NO_FIELDS
35-
else # rule == frule
36-
# Got to input extra first aguement for internals
37-
# But it is only a dummy since this is not a functor
38-
∂x = prop_rule(NamedTuple(), 1)
39-
end
40-
41-
42-
# Check that we get the derivative right:
43-
if !test_wirtinger
44-
@test isapprox(
45-
∂x, fdm(f, x);
46-
rtol=rtol, atol=atol, kwargs...
47-
)
48-
else
49-
# For complex arguments, also check if the wirtinger derivative is correct
50-
∂Re = fdm-> f(x + ϵ), 0)
51-
∂Im = fdm-> f(x + im*ϵ), 0)
52-
= 0.5(∂Re - im*∂Im)
53-
∂̅ = 0.5(∂Re + im*∂Im)
54-
@test isapprox(
55-
wirtinger_primal(∂x), ∂;
56-
rtol=rtol, atol=atol, kwargs...
57-
)
58-
@test isapprox(
59-
wirtinger_conjugate(∂x), ∂̅;
60-
rtol=rtol, atol=atol, kwargs...
61-
)
6235
end
6336
end
6437
end
@@ -92,7 +65,9 @@ end
9265
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
9366
ensure_not_running_on_functor(f, "frule_test")
9467
xs, ẋs = collect(zip(xẋs...))
95-
Ω, pushforward = ChainRules.frule(f, xs...)
68+
dself = Zero()
69+
Ω, dΩ_ad = ChainRules.frule(f, xs..., dself, ẋs...)
70+
dΩ_ad = dΩ_ad .+ 0
9671
@test f(xs...) == Ω
9772
dΩ_ad = pushforward(NamedTuple(), ẋs...)
9873

@@ -204,10 +179,6 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
204179
end
205180
end
206181

207-
function Base.isapprox(ad::Wirtinger, fd; kwargs...)
208-
error("Finite differencing with Wirtinger rules not implemented")
209-
end
210-
211182
function Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...)
212183
error("Tried to differentiate w.r.t. a `DoesNotExist`")
213184
end

0 commit comments

Comments
 (0)