This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit d92300d

Try #321:
2 parents dadee48 + 3fc544e

7 files changed: 181 additions, 3 deletions

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
 CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

src/CuArrays.jl

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,7 @@ using GPUArrays
 
 export CuArray, CuVector, CuMatrix, CuVecOrMat, cu
 
-import LinearAlgebra
+import LinearAlgebra, SpecialFunctions
 
 using Adapt
 
@@ -31,6 +31,7 @@ include("array.jl")
 include("subarray.jl")
 include("utils.jl")
 include("indexing.jl")
+include("special/gamma.jl")
 include("broadcast.jl")
 include("matmul.jl")
 include("mapreduce.jl")

src/broadcast.jl

Lines changed: 5 additions & 0 deletions
@@ -40,6 +40,11 @@ for f in libdevice
   @eval cufunc(::typeof(Base.$f)) = CUDAnative.$f
 end
 
+cufunc(::typeof(SpecialFunctions.lbeta)) = CuArrays.lbeta
+cufunc(::typeof(SpecialFunctions.lgamma)) = CuArrays.lgamma
+cufunc(::typeof(SpecialFunctions.digamma)) = CuArrays.digamma
+cufunc(::typeof(SpecialFunctions.trigamma)) = CuArrays.trigamma
+
 #broadcast ^
 culiteral_pow(::typeof(^), x::Union{Float32, Float64}, ::Val{0}) = one(x)
 culiteral_pow(::typeof(^), x::Union{Float32, Float64}, ::Val{1}) = x
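
These cufunc mappings let broadcasts written against the CPU-facing SpecialFunctions names run on the GPU: when broadcasting over a CuArray, the broadcast machinery substitutes the device implementations added in src/special/gamma.jl through cufunc. A minimal usage sketch (not part of the commit; the input array is made up):

using CuArrays, SpecialFunctions

xs = cu(rand(Float32, 4) .+ 1f0)   # hypothetical positive inputs on the GPU
ys = SpecialFunctions.lgamma.(xs)  # rewritten via cufunc to CuArrays.lgamma inside the kernel
Array(ys)                          # copy back to the host for inspection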

src/forwarddiff.jl

Lines changed: 13 additions & 2 deletions
@@ -37,9 +37,20 @@ ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :abs, 1)] = x ->
   :(signbit(x) ? -one(x) : one(x))
 eval(ForwardDiff.unary_dual_definition(:CUDAnative, :abs))
 
+# byhand: lgamma
+ForwardDiff.DiffRules.@define_diffrule CuArrays.lgamma(a) = :(CuArrays.digamma($a))
+eval(ForwardDiff.unary_dual_definition(:CuArrays, :lgamma))
 
-ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :pow, 2)] = (x, y) ->
-  replace_device.(ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base, :^, 2)](x, y))
+# byhand: digamma
+ForwardDiff.DiffRules.@define_diffrule CuArrays.digamma(a) = :(CuArrays.trigamma($a))
+eval(ForwardDiff.unary_dual_definition(:CuArrays, :digamma))
+
+# byhand: lbeta
+ForwardDiff.DiffRules.@define_diffrule CuArrays.lbeta(a, b) = :(CuArrays.digamma($a) - CuArrays.digamma($a + $b)), :(CuArrays.digamma($b) - CuArrays.digamma($a + $b))
+eval(ForwardDiff.binary_dual_definition(:CuArrays, :lbeta))
+
+ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :pow, 2)] =
+  (x, y) -> replace_device.(ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base, :^, 2)](x, y))
 
 @eval begin
   ForwardDiff.@define_binary_dual_op(
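
For context, the lbeta rule above is the chain rule applied to lbeta(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b) together with d/dx lgamma(x) = digamma(x), which gives ∂/∂a = digamma(a) - digamma(a + b) and ∂/∂b = digamma(b) - digamma(a + b). A minimal CPU sanity check (not part of the commit; the sample values are made up):

import SpecialFunctions: lbeta, digamma

a, b, h = 1.7, 2.3, 1e-6
fd = (lbeta(a + h, b) - lbeta(a - h, b)) / (2h)        # central difference in a
isapprox(fd, digamma(a) - digamma(a + b); atol=1e-6)   # expect true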

src/special/gamma.jl

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# This file is heavily adapted from https://github.com/JuliaMath/SpecialFunctions.jl.
+# License is MIT: http://julialang.org/license
+
+function lgamma(x)
+    return CUDAnative.lgamma(x)
+end
+
+function digamma(x)
+    if x <= 0 # reflection formula
+        ψ = -π / CUDAnative.tan(π * x)
+        x = 1 - x
+    else
+        ψ = zero(x)
+    end
+    if x < 7
+        # shift using recurrence formula
+        ν = one(x)
+        n = 7 - CUDAnative.floor(x)
+        while ν <= n - 1
+            ψ -= inv(x + ν)
+            ν += one(x)
+        end
+        ψ -= inv(x)
+        x += n
+    end
+    t = inv(x)
+    ψ += CUDAnative.log(x) - 0.5 * t
+    t *= t # 1/z^2
+    # the coefficients here are Float64(bernoulli[2:9] .// (2*(1:8)))
+    ψ -= t * @evalpoly(t, 0.08333333333333333, -0.008333333333333333, 0.003968253968253968, -0.004166666666666667, 0.007575757575757576, -0.021092796092796094, 0.08333333333333333, -0.4432598039215686)
+    return ψ
+end
+
+function _trigamma(x)
+    ψ = zero(x)
+    if x < 8
+        # shift using recurrence formula
+        n = 8 - CUDAnative.floor(x)
+        ψ += inv(x)^2
+        ν = one(x)
+        while ν <= n - 1
+            ψ += inv(x + ν)^2
+            ν += one(x)
+        end
+        x += n
+    end
+    t = inv(x)
+    w = t * t # 1/z^2
+    ψ += t + 0.5 * w
+    # the coefficients here are Float64(bernoulli[2:9])
+    ψ += t * w * @evalpoly(w, 0.16666666666666666, -0.03333333333333333, 0.023809523809523808, -0.03333333333333333, 0.07575757575757576, -0.2531135531135531, 1.1666666666666667, -7.092156862745098)
+    return ψ
+end
+
+function trigamma(x)
+    if x <= 0 # reflection formula
+        return (π / CUDAnative.sin(π * x))^2 - _trigamma(1 - x)
+    else
+        return _trigamma(x)
+    end
+end
+
+lbeta(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)
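
For reference (not part of the diff), the digamma kernel follows the usual SpecialFunctions.jl scheme: a reflection step ψ(x) = ψ(1 − x) − π/tan(πx) for non-positive arguments, the recurrence ψ(x + 1) = ψ(x) + 1/x to shift the argument to x ≥ 7, and then the truncated asymptotic series ψ(x) ≈ ln x − 1/(2x) − Σₙ₌₁⁸ B₂ₙ/(2n·x²ⁿ), whose coefficients are the Float64 Bernoulli ratios in the @evalpoly call. trigamma uses the analogous identities for the derivative, with (π/sin(πx))² as the reflection term and the Bernoulli numbers themselves (without the 1/(2n) factor) in its series; lbeta is defined directly in terms of lgamma.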

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ include("fft.jl")
 include("sparse.jl")
 include("solver.jl")
 include("sparse_solver.jl")
+include("special.jl")
 include("dnn.jl")
 include("forwarddiff.jl")

test/special.jl

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+using Test
+import SpecialFunctions
+using Flux: Tracker
+using CuArrays
+
+n = 1000
+
+xs_lgamma = randn(Float32, n); xs_lgamma_cu = cu(xs_lgamma)
+xs_digamma = randn(Float32, n); xs_digamma_cu = cu(xs_digamma)
+xs_trigamma = randn(Float32, n); xs_trigamma_cu = cu(xs_trigamma)
+xs_lbeta_tuple = (randn(Float32, n), randn(Float32, n))
+xs_lbeta_tuple = map(xs -> abs.(xs), xs_lbeta_tuple); xs_lbeta_cu_tuple = map(cu, xs_lbeta_tuple)
+
+catgrads(grads) = cat(map(ta -> ta.data, grads)...; dims=1)
+g∑fx(f, xs) = catgrads(Tracker.gradient(_xs -> sum(f.(_xs)), xs))
+g∑fx(f, xs, ys) = catgrads(Tracker.gradient((_xs, _ys) -> sum(f.(_xs, _ys)), xs, ys))
+
+results = Dict()
+@testset "Forward evaluation" begin
+    fn = :lgamma
+    @testset "$fn" begin
+        lgamma_val_cpu = @time SpecialFunctions.lgamma.(xs_lgamma)
+        lgamma_val_gpu = @time CuArrays.lgamma.(xs_lgamma_cu)
+        lgamma_val_gpu = Array(lgamma_val_gpu)
+        for i = 1:n
+            @test lgamma_val_cpu[i] ≈ lgamma_val_gpu[i]
+        end
+        results[fn] = (lgamma_val_cpu, lgamma_val_gpu)
+    end
+
+    fn = :digamma
+    @testset "$fn" begin
+        digamma_val_cpu = @time SpecialFunctions.digamma.(xs_digamma)
+        digamma_val_gpu = @time CuArrays.digamma.(xs_digamma_cu)
+        digamma_val_gpu = Array(digamma_val_gpu)
+        for i = 1:n
+            @test digamma_val_cpu[i] ≈ digamma_val_gpu[i]
+        end
+        results[fn] = (digamma_val_cpu, digamma_val_gpu)
+    end
+
+    fn = :trigamma
+    @testset "$fn" begin
+        trigamma_val_cpu = @time SpecialFunctions.trigamma.(xs_trigamma)
+        trigamma_val_gpu = @time CuArrays.trigamma.(xs_trigamma_cu)
+        trigamma_val_gpu = Array(trigamma_val_gpu)
+        for i = 1:n
+            @test trigamma_val_cpu[i] ≈ trigamma_val_gpu[i]
+        end
+        results[fn] = (trigamma_val_cpu, trigamma_val_gpu)
+    end
+
+    fn = :lbeta
+    @testset "$fn" begin
+        lbeta_val_cpu = @time SpecialFunctions.lbeta.(xs_lbeta_tuple...)
+        lbeta_val_gpu = @time CuArrays.lbeta.(xs_lbeta_cu_tuple...)
+        lbeta_val_gpu = Array(lbeta_val_gpu)
+        for i = 1:n
+            @test lbeta_val_cpu[i] ≈ lbeta_val_gpu[i]
+        end
+        results[fn] = (lbeta_val_cpu, lbeta_val_gpu)
+    end
+
+end
+
+@testset "Gradient evaluation" begin
+    fn = :lgamma
+    @testset "$fn" begin
+        lgamma_grad_cpu = @time g∑fx(SpecialFunctions.lgamma, xs_lgamma)
+        lgamma_grad_gpu = @time g∑fx(CuArrays.lgamma, xs_lgamma_cu)
+        lgamma_grad_gpu = Array(lgamma_grad_gpu)
+        for i = 1:n
+            @test lgamma_grad_cpu[i] ≈ lgamma_grad_gpu[i]
+        end
+    end
+
+    fn = :digamma
+    @testset "$fn" begin
+        digamma_grad_cpu = @time g∑fx(SpecialFunctions.digamma, xs_digamma)
+        digamma_grad_gpu = @time g∑fx(CuArrays.digamma, xs_digamma_cu)
+        digamma_grad_gpu = Array(digamma_grad_gpu)
+        for i = 1:n
+            @test digamma_grad_cpu[i] ≈ digamma_grad_gpu[i]
+        end
+    end
+
+    fn = :lbeta
+    @testset "$fn" begin
+        lbeta_grad_cpu = @time g∑fx(SpecialFunctions.lbeta, xs_lbeta_tuple...)
+        lbeta_grad_gpu = @time g∑fx(CuArrays.lbeta, xs_lbeta_cu_tuple...)
+        lbeta_grad_gpu = Array(lbeta_grad_gpu)
+        for i = 1:n
+            @test lbeta_grad_cpu[i] ≈ lbeta_grad_gpu[i]
+        end
+    end
+end
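
A minimal sketch (hypothetical invocation, not part of the commit) for exercising just these tests during development, assuming a CUDA-capable GPU and that Flux is available in the active environment:

julia> using CuArrays               # from a checkout of the repository
julia> include("test/special.jl")   # runs the forward and gradient test sets above

In the full suite they run through the include("special.jl") line added to test/runtests.jl.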
