Skip to content

Commit 63fe051

Browse files
simeonschaubnickrobinson251
authored andcommitted
Add basic complex functions, fix 2-arg atan (JuliaDiff#64)
* add rules for complex util functions * fix accidentally deleted line * fix test cases
1 parent 52591a0 commit 63fe051

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

src/rulesets/Base/base.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,20 @@
8383

8484
@scalar_rule(hypot(x, y), (x / Ω, y / Ω))
8585
@scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx)
86-
@scalar_rule(atan(x, y), @setup(u = x^2 + y^2), (y / u, -x / u))
86+
@scalar_rule(atan(y, x), @setup(u = x^2 + y^2), (x / u, -y / u))
8787

8888
@scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt))
8989
@scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt))
9090
@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16)),
9191
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -floor(u))))
9292
@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16)),
9393
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
94+
@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u))
95+
@scalar_rule(angle(x::Real), Zero())
96+
@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2))
97+
@scalar_rule(real(x::Real), One())
98+
@scalar_rule(imag(x::Complex), Wirtinger(-im//2, im//2))
99+
@scalar_rule(imag(x::Real), Zero())
94100

95101
# product rule requires special care for arguments where `mul` is non-commutative
96102

test/rulesets/Base/base.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,19 @@
9999
end
100100
end
101101

102+
@testset "Unary complex functions" begin
103+
for x in (-6, rand.((Float32, Float64, Complex{Float32}, Complex{Float64}))...)
104+
rtol = x isa Complex{Float32} ? 1e-6 : 1e-9
105+
test_scalar(real, x; rtol=rtol)
106+
test_scalar(imag, x; rtol=rtol)
107+
# TODO: implement correct complex derivative
108+
x isa Real && test_scalar(abs, x; rtol=rtol)
109+
test_scalar(angle, x; rtol=rtol)
110+
test_scalar(abs2, x; rtol=rtol)
111+
test_scalar(conj, x; rtol=rtol)
112+
end
113+
end
114+
102115
@testset "*(x, y)" begin
103116
x, y = rand(3, 2), rand(2, 5)
104117
z, (dx, dy) = rrule(*, x, y)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Test
1212
# For testing purposes we use a lot of
1313
using ChainRulesCore: cast, extern, accumulate, accumulate!, store!, @scalar_rule,
1414
Wirtinger, wirtinger_primal, wirtinger_conjugate,
15-
Zero, One, Casted, DNE, Thunk, DNERule
15+
Zero, One, Casted, DNE, Thunk, DNERule, AbstractDifferential
1616

1717
include("test_util.jl")
1818

0 commit comments

Comments
 (0)