Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 61c66f4

Browse files
committed
bugfix
1 parent 87733ac commit 61c66f4

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

src/forwarddiff.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# ForwardDiff integration
22

3-
byhand = [:lgamma, :digamma, :lbeta, :exp2, :log2, :exp10, :log10, :abs]
3+
byhand = [:exp2, :log2, :exp10, :log10, :abs]
44

55
for f in libdevice
66
if haskey(ForwardDiff.DiffRules.DEFINED_DIFFRULES, (:Base,f,1))
@@ -12,18 +12,6 @@ for f in libdevice
1212
end
1313
end
1414

15-
# byhand: lgamma
16-
ForwardDiff.DiffRules.@define_diffrule CuArrays.lgamma(a) = :(CuArrays.digamma($a))
17-
eval(ForwardDiff.unary_dual_definition(:CuArrays, :lgamma))
18-
19-
# byhand: digamma
20-
ForwardDiff.DiffRules.@define_diffrule CuArrays.digamma(a) = :(CuArrays.trigamma($a))
21-
eval(ForwardDiff.unary_dual_definition(:CuArrays, :digamma))
22-
23-
# byhand: lbeta
24-
ForwardDiff.DiffRules.@define_diffrule CuArrays.lbeta(a, b) = :(CuArrays.digamma($a) - CuArrays.digamma($a + $b)), :(CuArrays.digamma($b) - CuArrays.digamma($a + $b))
25-
eval(ForwardDiff.binary_dual_definition(:CuArrays, :lbeta))
26-
2715
# byhand: exp2
2816
ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :exp2, 1)] = x ->
2917
:((CuArrays.cufunc(exp2))(x) * (CuArrays.cufunc(log))(oftype(x, 2)))
@@ -49,9 +37,20 @@ ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :abs, 1)] = x ->
4937
:(signbit(x) ? -one(x) : one(x))
5038
eval(ForwardDiff.unary_dual_definition(:CUDAnative, :abs))
5139

40+
# byhand: lgamma
41+
ForwardDiff.DiffRules.@define_diffrule CuArrays.lgamma(a) = :(CuArrays.digamma($a))
42+
eval(ForwardDiff.unary_dual_definition(:CuArrays, :lgamma))
43+
44+
# byhand: digamma
45+
ForwardDiff.DiffRules.@define_diffrule CuArrays.digamma(a) = :(CuArrays.trigamma($a))
46+
eval(ForwardDiff.unary_dual_definition(:CuArrays, :digamma))
47+
48+
# byhand: lbeta
49+
ForwardDiff.DiffRules.@define_diffrule CuArrays.lbeta(a, b) = :(CuArrays.digamma($a) - CuArrays.digamma($a + $b)), :(CuArrays.digamma($b) - CuArrays.digamma($a + $b))
50+
eval(ForwardDiff.binary_dual_definition(:CuArrays, :lbeta))
5251

53-
ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :pow, 2)] = (x, y) ->
54-
replace_device.(ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base, :^, 2)](x, y))
52+
ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :pow, 2)] =
53+
(x, y) -> replace_device.(ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base, :^, 2)](x, y))
5554

5655
@eval begin
5756
ForwardDiff.@define_binary_dual_op(

src/special/gamma.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,4 @@ function trigamma(x)
6060
end
6161
end
6262

63-
function lbeta(x, y)
64-
return CUDAnative.lgamma(x) + CUDAnative.lgamma(y) - CUDAnative.lgamma(x + y)
65-
end
63+
lbeta(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)

0 commit comments

Comments
 (0)