Skip to content

Commit 3a16dab

Browse files
Shashi GowdaYingboMa
andcommitted
WIP: adding more rules and fixing old ones (#80)
* fma * add a few NaNMath rules * add some SpecialFunctions rules * isint wasn't defined * fixes * optimize rule for tanh like in JuliaDiff/DiffRules.jl#4861e3 * Fix pow. Co-authored-by: YingboMa <mayingbo5@gmail.com> * Remove tan from NaNMath.jl * Fix `inv` and add `muladd` * Address code review comments * Test new base rules * Patch version bump Co-authored-by: Yingbo Ma <mayingbo5@gmail.com>
1 parent 90b08a4 commit 3a16dab

File tree

5 files changed

+71
-10
lines changed

5 files changed

+71
-10
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.2.3"
3+
version = "0.2.4"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313
ChainRulesCore = "0.4"
1414
FiniteDifferences = "^0.7"
1515
Reexport = "0.2"
16-
Requires = "0.5.2"
16+
Requires = "0.5.2, 1"
1717
julia = "^1.0"
1818

1919
[extras]

src/rulesets/Base/base.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
@scalar_rule(sinh(x), cosh(x))
3838
@scalar_rule(cosh(x), sinh(x))
39-
@scalar_rule(tanh(x), sech(x)^2)
39+
@scalar_rule(tanh(x), 1-Ω^2)
4040
@scalar_rule(coth(x), -(csch(x)^2))
4141

4242
@scalar_rule(asinh(x), inv(sqrt(x^2 + 1)))
@@ -66,7 +66,7 @@
6666
@scalar_rule(-(x, y), (One(), -1))
6767
@scalar_rule(/(x, y), (inv(y), -(x / y / y)))
6868
@scalar_rule(\(x, y), (-(y / x / x), inv(x)))
69-
@scalar_rule(^(x, y), (y * x^(y - 1), Ω * log(x)))
69+
@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(x)))
7070

7171
@scalar_rule(inv(x), -Ω^2)
7272
@scalar_rule(sqrt(x), inv(2 * Ω))
@@ -92,10 +92,12 @@
9292

9393
@scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt))
9494
@scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt))
95-
@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16)),
95+
@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
9696
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -floor(u))))
97-
@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16)),
97+
@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
9898
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
99+
@scalar_rule(fma(x, y, z), (y, x, One()))
100+
@scalar_rule(muladd(x, y, z), (y, x, One()))
99101
@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u))
100102
@scalar_rule(angle(x::Real), Zero())
101103
@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2))

src/rulesets/packages/NaNMath.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module NaNMathGlue
22
using ChainRulesCore
33
using ..NaNMath
4+
using ..SpecialFunctions
45

56
@scalar_rule(NaNMath.sin(x), NaNMath.cos(x))
67
@scalar_rule(NaNMath.cos(x), -NaNMath.sin(x))
@@ -15,5 +16,11 @@ using ..NaNMath
1516
@scalar_rule(NaNMath.lgamma(x), SpecialFunctions.digamma(x))
1617
@scalar_rule(NaNMath.sqrt(x), inv(2 * Ω))
1718
@scalar_rule(NaNMath.pow(x, y), (y * NaNMath.pow(x, y - 1), Ω * NaNMath.log(x)))
19+
@scalar_rule(NaNMath.max(x, y),
20+
(ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())),
21+
ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero()))))
22+
@scalar_rule(NaNMath.min(x, y),
23+
(ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())),
24+
ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero()))))
1825

1926
end #module

src/rulesets/packages/SpecialFunctions.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,47 @@ using ..SpecialFunctions
2323
@scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
2424
@scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω))
2525

26+
# binary
27+
@scalar_rule(SpecialFunctions.besselj(ν, x),
28+
(NaN,
29+
(SpecialFunctions.besselj- 1, x) -
30+
SpecialFunctions.besselj+ 1, x)) / 2))
31+
32+
@scalar_rule(SpecialFunctions.besseli(ν, x),
33+
(NaN,
34+
(SpecialFunctions.besseli- 1, x) +
35+
SpecialFunctions.besseli+ 1, x)) / 2))
36+
@scalar_rule(SpecialFunctions.bessely(ν, x),
37+
(NaN,
38+
(SpecialFunctions.bessely- 1, x) -
39+
SpecialFunctions.bessely+ 1, x)) / 2))
40+
41+
@scalar_rule(SpecialFunctions.besselk(ν, x),
42+
(NaN,
43+
-(SpecialFunctions.besselk- 1, x) +
44+
SpecialFunctions.besselk+ 1, x)) / 2))
45+
46+
@scalar_rule(SpecialFunctions.hankelh1(ν, x),
47+
(NaN,
48+
(SpecialFunctions.hankelh1- 1, x) -
49+
SpecialFunctions.hankelh1+ 1, x)) / 2))
50+
@scalar_rule(SpecialFunctions.hankelh2(ν, x),
51+
(NaN,
52+
(SpecialFunctions.hankelh2- 1, x) -
53+
SpecialFunctions.hankelh2+ 1, x)) / 2))
54+
55+
@scalar_rule(SpecialFunctions.polygamma(m, x),
56+
(NaN, SpecialFunctions.polygamma(m + 1, x)))
57+
58+
# todo: setup for common expr
59+
@scalar_rule(SpecialFunctions.beta(a, b),
60+
*(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b)),
61+
Ω*(SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b))))
62+
63+
@scalar_rule(SpecialFunctions.lbeta(a, b),
64+
(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b),
65+
SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b)))
66+
2667
# Changes between SpecialFunctions 0.7 and 0.8
2768
if isdefined(SpecialFunctions, :lgamma)
2869
# actually is the absolute value of the logorithm of gamma

test/rulesets/Base/base.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@
137137
test_accumulation(rand(2, 5), dy)
138138
end
139139

140-
@testset "binary trig ($f)" for f in (hypot, atan)
140+
@testset "binary function ($f)" for f in (hypot, atan, mod, rem, ^)
141141
rng = MersenneTwister(123456)
142-
x, Δx, x̄ = 10randn(rng, 3)
143-
y, Δy, ȳ = randn(rng, 3)
144-
Δz = randn(rng)
142+
x, Δx, x̄ = 10rand(rng, 3)
143+
y, Δy, ȳ = rand(rng, 3)
144+
Δz = rand(rng)
145145

146146
frule_test(f, (x, Δx), (y, Δy))
147147
rrule_test(f, Δz, (x, x̄), (y, ȳ))
@@ -176,4 +176,15 @@
176176
@test extern(ẏ) == 0
177177
end
178178
end
179+
180+
@testset "trinary ($f)" for f in (muladd, fma)
181+
rng = MersenneTwister(123456)
182+
x, Δx, x̄ = 10randn(rng, 3)
183+
y, Δy, ȳ = randn(rng, 3)
184+
z, Δz, z̄ = randn(rng, 3)
185+
Δk = randn(rng)
186+
187+
frule_test(f, (x, Δx), (y, Δy), (z, Δz))
188+
rrule_test(f, Δk, (x, x̄), (y, ȳ), (z, z̄))
189+
end
179190
end

0 commit comments

Comments
 (0)