Skip to content

Commit ce1618a

Browse files
committed
Add multivariate rule tests
1 parent 1749566 commit ce1618a

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

test/rulesets/Base/base.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,27 @@ end
6262
test_scalar(acscd, x -> -180 / π / abs(x) / sqrt(x^2 - 1), x)
6363
test_scalar(acotd, x -> -180 / π / (1 + x^2), x)
6464
end
65-
# TODO: atan2 sincos
65+
@testset "Multivariate" begin
66+
x, y = rand(2)
67+
ratan = atan(x, y) # https://en.wikipedia.org/wiki/Atan2
68+
u = x^2 + y^2
69+
datan = y/u - 2x/u
70+
r, df = frule(atan, x, y)
71+
@test r === ratan
72+
@test df(1, 2) === datan
73+
r, (df1, df2) = rrule(atan, x, y)
74+
@test r === ratan
75+
@test df1(1) + df2(2) === datan
76+
77+
rsincos = sincos(x)
78+
dsincos = cos(x) - 2sin(x)
79+
r, (df1, df2) = frule(sincos, x)
80+
@test r === rsincos
81+
@test df1(1) + df2(2) === dsincos
82+
r, df = rrule(sincos, x)
83+
@test r === rsincos
84+
@test df(1, 2) === dsincos
85+
end
6686
end
6787
@testset "Misc. Tests" begin
6888
@testset "*(x, y)" begin

0 commit comments

Comments
 (0)