diff --git a/Project.toml b/Project.toml
index 66816490..e59dd9ca 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,6 +10,7 @@ CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
 CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
diff --git a/src/CuArrays.jl b/src/CuArrays.jl
index dcb27410..528bd722 100644
--- a/src/CuArrays.jl
+++ b/src/CuArrays.jl
@@ -6,7 +6,7 @@ using GPUArrays
 
 export CuArray, CuVector, CuMatrix, CuVecOrMat, cu
 
-import LinearAlgebra
+import LinearAlgebra, SpecialFunctions
 
 using Adapt
 
@@ -31,6 +31,7 @@ include("array.jl")
 include("subarray.jl")
 include("utils.jl")
 include("indexing.jl")
+include("special/gamma.jl")
 include("broadcast.jl")
 include("matmul.jl")
 include("mapreduce.jl")
diff --git a/src/broadcast.jl b/src/broadcast.jl
index 87cb7679..b4e56979 100644
--- a/src/broadcast.jl
+++ b/src/broadcast.jl
@@ -40,6 +40,11 @@ for f in libdevice
   @eval cufunc(::typeof(Base.$f)) = CUDAnative.$f
 end
 
+cufunc(::typeof(SpecialFunctions.lbeta)) = CuArrays.lbeta
+cufunc(::typeof(SpecialFunctions.lgamma)) = CuArrays.lgamma
+cufunc(::typeof(SpecialFunctions.digamma)) = CuArrays.digamma
+cufunc(::typeof(SpecialFunctions.trigamma)) = CuArrays.trigamma
+
 #broadcast ^
 culiteral_pow(::typeof(^), x::Union{Float32, Float64}, ::Val{0}) = one(x)
 culiteral_pow(::typeof(^), x::Union{Float32, Float64}, ::Val{1}) = x
diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl
index f5328909..5da13f8f 100644
--- a/src/forwarddiff.jl
+++ b/src/forwarddiff.jl
@@ -37,9 +37,20 @@ ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :abs, 1)] =
   x -> :(signbit(x) ? -one(x) : one(x))
 eval(ForwardDiff.unary_dual_definition(:CUDAnative, :abs))
 
+# byhand: lgamma
+ForwardDiff.DiffRules.@define_diffrule CuArrays.lgamma(a) = :(CuArrays.digamma($a))
+eval(ForwardDiff.unary_dual_definition(:CuArrays, :lgamma))
 
-ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :pow, 2)] = (x, y) ->
-  replace_device.(ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base, :^, 2)](x, y))
+# byhand: digamma
+ForwardDiff.DiffRules.@define_diffrule CuArrays.digamma(a) = :(CuArrays.trigamma($a))
+eval(ForwardDiff.unary_dual_definition(:CuArrays, :digamma))
+
+# byhand: lbeta
+ForwardDiff.DiffRules.@define_diffrule CuArrays.lbeta(a, b) = :(CuArrays.digamma($a) - CuArrays.digamma($a + $b)), :(CuArrays.digamma($b) - CuArrays.digamma($a + $b))
+eval(ForwardDiff.binary_dual_definition(:CuArrays, :lbeta))
+
+ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :pow, 2)] =
+  (x, y) -> replace_device.(ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base, :^, 2)](x, y))
 
 @eval begin
   ForwardDiff.@define_binary_dual_op(
diff --git a/src/special/gamma.jl b/src/special/gamma.jl
new file mode 100644
index 00000000..e0090f88
--- /dev/null
+++ b/src/special/gamma.jl
@@ -0,0 +1,63 @@
+# This file is heavily adapted from https://github.com/JuliaMath/SpecialFunctions.jl.
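+# lgamma below simply wraps CUDAnative.lgamma; digamma and trigamma follow the CPU
+# implementations: a reflection formula for x <= 0, a recurrence shift into the
+# asymptotic region, and an asymptotic series with Bernoulli-number coefficients.
+# lbeta is composed from lgamma.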
+# License is MIT: http://julialang.org/license
+
+function lgamma(x)
+    return CUDAnative.lgamma(x)
+end
+
+function digamma(x)
+    if x <= 0 # reflection formula
+        ψ = -π / CUDAnative.tan(π * x)
+        x = 1 - x
+    else
+        ψ = zero(x)
+    end
+    if x < 7
+        # shift using recurrence formula
+        ν = one(x)
+        n = 7 - CUDAnative.floor(x)
+        while ν <= n - 1
+            ψ -= inv(x + ν)
+            ν += one(x)
+        end
+        ψ -= inv(x)
+        x += n
+    end
+    t = inv(x)
+    ψ += CUDAnative.log(x) - 0.5 * t
+    t *= t # 1/z^2
+    # the coefficients here are Float64(bernoulli[2:9] .// (2*(1:8)))
+    ψ -= t * @evalpoly(t,0.08333333333333333,-0.008333333333333333,0.003968253968253968,-0.004166666666666667,0.007575757575757576,-0.021092796092796094,0.08333333333333333,-0.4432598039215686)
+    return ψ
+end
+
+function _trigamma(x)
+    ψ = zero(x)
+    if x < 8
+        # shift using recurrence formula
+        n = 8 - CUDAnative.floor(x)
+        ψ += inv(x)^2
+        ν = one(x)
+        while ν <= n - 1
+            ψ += inv(x + ν)^2
+            ν += one(x)
+        end
+        x += n
+    end
+    t = inv(x)
+    w = t * t # 1/z^2
+    ψ += t + 0.5 * w
+    # the coefficients here are Float64(bernoulli[2:9])
+    ψ += t * w * @evalpoly(w,0.16666666666666666,-0.03333333333333333,0.023809523809523808,-0.03333333333333333,0.07575757575757576,-0.2531135531135531,1.1666666666666667,-7.092156862745098)
+    return ψ
+end
+
+function trigamma(x)
+    if x <= 0 # reflection formula
+        return (π / CUDAnative.sin(π * x))^2 - _trigamma(1 - x)
+    else
+        return _trigamma(x)
+    end
+end
+
+lbeta(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)
diff --git a/test/runtests.jl b/test/runtests.jl
index f256a793..2c027334 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -24,6 +24,7 @@ include("fft.jl")
 include("sparse.jl")
 include("solver.jl")
 include("sparse_solver.jl")
+include("special.jl")
 include("dnn.jl")
 include("forwarddiff.jl")
diff --git a/test/special.jl b/test/special.jl
new file mode 100644
index 00000000..5be974c9
--- /dev/null
+++ b/test/special.jl
@@ -0,0 +1,96 @@
+using Test
+import SpecialFunctions
+using Flux: Tracker
+using CuArrays
+
+n = 1000
+
+xs_lgamma = randn(Float32, n); xs_lgamma_cu = cu(xs_lgamma)
+xs_digamma = randn(Float32, n); xs_digamma_cu = cu(xs_digamma)
+xs_trigamma = randn(Float32, n); xs_trigamma_cu = cu(xs_trigamma)
+xs_lbeta_tuple = (randn(Float32, n), randn(Float32, n))
+xs_lbeta_tuple = map(xs -> abs.(xs), xs_lbeta_tuple); xs_lbeta_cu_tuple = map(cu, xs_lbeta_tuple)
+
+catgrads(grads) = cat(map(ta -> ta.data, grads)...; dims=1)
+g∑fx(f, xs) = catgrads(Tracker.gradient(_xs -> sum(f.(_xs)), xs))
+g∑fx(f, xs, ys) = catgrads(Tracker.gradient((_xs, _ys) -> sum(f.(_xs, _ys)), xs, ys))
+
+results = Dict()
+@testset "Forward evaluation" begin
+    fn = :lgamma
+    @testset "$fn" begin
+        lgamma_val_cpu = @time SpecialFunctions.lgamma.(xs_lgamma)
+        lgamma_val_gpu = @time CuArrays.lgamma.(xs_lgamma_cu)
+        lgamma_val_gpu = Array(lgamma_val_gpu)
+        for i = 1:n
+            @test lgamma_val_cpu[i] ≈ lgamma_val_gpu[i]
+        end
+        results[fn] = (lgamma_val_cpu, lgamma_val_gpu)
+    end
+
+    fn = :digamma
+    @testset "$fn" begin
+        digamma_val_cpu = @time SpecialFunctions.digamma.(xs_digamma)
+        digamma_val_gpu = @time CuArrays.digamma.(xs_digamma_cu)
+        digamma_val_gpu = Array(digamma_val_gpu)
+        for i = 1:n
+            @test digamma_val_cpu[i] ≈ digamma_val_gpu[i]
+        end
+        results[fn] = (digamma_val_cpu, digamma_val_gpu)
+    end
+
+    fn = :trigamma
+    @testset "$fn" begin
+        trigamma_val_cpu = @time SpecialFunctions.trigamma.(xs_trigamma)
+        trigamma_val_gpu = @time CuArrays.trigamma.(xs_trigamma_cu)
+        trigamma_val_gpu = Array(trigamma_val_gpu)
+        for i = 1:n
+            @test trigamma_val_cpu[i] ≈ trigamma_val_gpu[i]
+        end
+        results[fn] = (trigamma_val_cpu, trigamma_val_gpu)
+    end
+
+    fn = :lbeta
+    @testset "$fn" begin
+        lbeta_val_cpu = @time SpecialFunctions.lbeta.(xs_lbeta_tuple...)
+        lbeta_val_gpu = @time CuArrays.lbeta.(xs_lbeta_cu_tuple...)
+        lbeta_val_gpu = Array(lbeta_val_gpu)
+        for i = 1:n
+            @test lbeta_val_cpu[i] ≈ lbeta_val_gpu[i]
+        end
+        results[fn] = (lbeta_val_cpu, lbeta_val_gpu)
+    end
+
+end
+
+@testset "Gradient evaluation" begin
+    fn = :lgamma
+    @testset "$fn" begin
+        lgamma_grad_cpu = @time g∑fx(SpecialFunctions.lgamma, xs_lgamma)
+        lgamma_grad_gpu = @time g∑fx(CuArrays.lgamma, xs_lgamma_cu)
+        lgamma_grad_gpu = Array(lgamma_grad_gpu)
+        for i = 1:n
+            @test lgamma_grad_cpu[i] ≈ lgamma_grad_gpu[i]
+        end
+    end
+
+    fn = :digamma
+    @testset "$fn" begin
+        digamma_grad_cpu = @time g∑fx(SpecialFunctions.digamma, xs_digamma)
+        digamma_grad_gpu = @time g∑fx(CuArrays.digamma, xs_digamma_cu)
+        digamma_grad_gpu = Array(digamma_grad_gpu)
+        for i = 1:n
+            @test digamma_grad_cpu[i] ≈ digamma_grad_gpu[i]
+        end
+    end
+
+    fn = :lbeta
+    @testset "$fn" begin
+        lbeta_grad_cpu = @time g∑fx(SpecialFunctions.lbeta, xs_lbeta_tuple...)
+        lbeta_grad_gpu = @time g∑fx(CuArrays.lbeta, xs_lbeta_cu_tuple...)
+        lbeta_grad_gpu = Array(lbeta_grad_gpu)
+        for i = 1:n
+            @test lbeta_grad_cpu[i] ≈ lbeta_grad_gpu[i]
+        end
+    end
+end
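Usage sketch (an editorial addition, not part of the patch; xs is a hypothetical input array): with this change applied, broadcasting the SpecialFunctions names over a CuArray should be rewritten by the cufunc mappings in src/broadcast.jl to the CUDAnative-based kernels in src/special/gamma.jl, and the ForwardDiff rules make those kernels differentiable, e.g. through Flux's Tracker as exercised in test/special.jl:

    using CuArrays, SpecialFunctions
    using Flux: Tracker

    xs = cu(rand(Float32, 1024) .+ 1f0)                  # strictly positive inputs on the GPU
    SpecialFunctions.lgamma.(xs)                          # should dispatch to CuArrays.lgamma via cufunc
    CuArrays.digamma.(xs)                                 # the kernels can also be broadcast directly
    Tracker.gradient(x -> sum(CuArrays.lgamma.(x)), xs)   # elementwise gradient equals digamma.(xs)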