Skip to content

Commit 115adfa

Browse files
committed
wip
1 parent 9e4cb76 commit 115adfa

File tree

5 files changed

+61
-7
lines changed

5 files changed

+61
-7
lines changed

src/rulesets/Base/base.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
@scalar_rule(one(x), Zero())
22
@scalar_rule(zero(x), Zero())
3+
@scalar_rule(sign(x), Zero())
4+
35
@scalar_rule(abs2(x), Wirtinger(x', x))
46
@scalar_rule(log(x), inv(x))
57
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))

src/rulesets/packages/SpecialFunctions.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ using ChainRulesCore
33
using ..SpecialFunctions
44

55

6-
@scalar_rule(SpecialFunctions.lgamma(x), SpecialFunctions.digamma(x))
76
@scalar_rule(SpecialFunctions.erf(x), (2 / sqrt(π)) * exp(-x * x))
87
@scalar_rule(SpecialFunctions.erfc(x), -(2 / sqrt(π)) * exp(-x * x))
98
@scalar_rule(SpecialFunctions.erfi(x), (2 / sqrt(π)) * exp(x * x))
@@ -24,4 +23,19 @@ using ..SpecialFunctions
2423
@scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
2524
@scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω))
2625

26+
# Changes between SpecialFunctions 0.7 and 0.8
27+
if isdefined(SpecialFunctions, :lgamma)
28+
# actually is the absolute value of the logorithm of gamma
29+
@scalar_rule(SpecialFunctions.lgamma(x), SpecialFunctions.digamma(x))
30+
end
31+
32+
if isdefined(SpecialFunctions, :logabsgamma)
33+
# actually is the absolute value of the logorithm of gamma, paired with sign gamma
34+
@scalar_rule(SpecialFunctions.logabsgamma(x), SpecialFunctions.digamma(x), Zero())
35+
end
36+
37+
if isdefined(SpecialFunctions, :loggamma)
38+
@scalar_rule(SpecialFunctions.loggamma(x), SpecialFunctions.digamma(x))
39+
end
40+
2741
end #module

test/rulesets/Base/base.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,23 @@
166166
test_scalar(one, x)
167167
test_scalar(zero, x)
168168
end
169+
170+
@testset "sign" begin
171+
@testset "at points" for x in (-1.1, -1.1, 0.5, 100)
172+
test_scalar(sign, x)
173+
end
174+
175+
@testset "Zero over the point discontinuity" begin
176+
# Can't do finite differencing because we are lying
177+
# following the subgradient convention.
178+
179+
_, pb = rrule(sign, 0.0)
180+
_, x̄ = pb(10.5)
181+
@test extern(x̄) == 0
182+
183+
_, pf = frule(sign, 0.0)
184+
= pf(NamedTuple(), 10.5)
185+
@test extern(ẏ) == 0
186+
end
187+
end
169188
end

test/rulesets/packages/SpecialFunctions.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using SpecialFunctions
2-
2+
#=
33
@testset "SpecialFunctions" for x in (1, -1, 0, 0.5, 10, -17.1, 1.5 + 0.7im)
44
test_scalar(SpecialFunctions.erf, x)
55
test_scalar(SpecialFunctions.erfc, x)
@@ -31,6 +31,24 @@ using SpecialFunctions
3131
test_scalar(SpecialFunctions.gamma, x)
3232
test_scalar(SpecialFunctions.digamma, x)
3333
test_scalar(SpecialFunctions.trigamma, x)
34-
test_scalar(SpecialFunctions.lgamma, x)
3534
end
3635
end
36+
=#
37+
@testset "log gamma and co" begin
38+
# SpecialFunctions 0.7->0.8 changes:
39+
for x in (1.5, 2.5, 10.5, -0.6, -2.6, 1.6+1.6im, 1.6-1.6im, -4.6+1.6im)
40+
if isdefined(SpecialFunctions, :lgamma)
41+
test_scalar(SpecialFunctions.lgamma, x)
42+
end
43+
if isdefined(SpecialFunctions, :loggamma)
44+
isreal(x) && x < 0 && continue
45+
test_scalar(SpecialFunctions.loggamma, x)
46+
end
47+
48+
if isdefined(SpecialFunctions, :logabsgamma)
49+
isreal(x) || continue
50+
51+
end
52+
end
53+
54+
end

test/runtests.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ println("Testing ChainRules.jl")
2222
@testset "rulesets" begin
2323
@testset "Base" begin
2424
include(joinpath("rulesets", "Base", "base.jl"))
25-
include(joinpath("rulesets", "Base", "array.jl"))
26-
include(joinpath("rulesets", "Base", "mapreduce.jl"))
27-
include(joinpath("rulesets", "Base", "broadcast.jl"))
25+
#include(joinpath("rulesets", "Base", "array.jl"))
26+
#include(joinpath("rulesets", "Base", "mapreduce.jl"))
27+
#include(joinpath("rulesets", "Base", "broadcast.jl"))
2828
end
29-
29+
#==
3030
print(" ")
3131
3232
@testset "Statistics" begin
@@ -41,6 +41,7 @@ println("Testing ChainRules.jl")
4141
include(joinpath("rulesets", "LinearAlgebra", "factorization.jl"))
4242
include(joinpath("rulesets", "LinearAlgebra", "blas.jl"))
4343
end
44+
==#
4445

4546
print(" ")
4647

0 commit comments

Comments
 (0)