Skip to content

Commit a14b50b

Browse files
Merge #115
115: add erf + erfc functions r=vchuravy a=simonbyrne Needed for CliMA/ClimateMachine.jl#1392 Co-authored-by: Simon Byrne <simonbyrne@gmail.com>
2 parents 5891656 + 2d8965b commit a14b50b

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

src/backends/cpu.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ for f in cpufuns
222222
end
223223

224224
@inline Cassette.overdub(::CPUCtx, ::typeof(SpecialFunctions.gamma), x::Union{Float32, Float64}) = SpecialFunctions.gamma(x)
225+
@inline Cassette.overdub(::CPUCtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = SpecialFunctions.erf(x)
226+
@inline Cassette.overdub(::CPUCtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = SpecialFunctions.erfc(x)
225227

226228
###
227229
# CPU implementation of shared memory

src/backends/cuda.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,8 @@ end
292292
@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = CUDA.exp(x)
293293

294294
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.gamma), x::Union{Float32, Float64}) = CUDA.tgamma(x)
295+
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = CUDA.erf(x)
296+
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = CUDA.erfc(x)
295297

296298
###
297299
# GPU implementation of shared memory

test/test.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,55 @@ end
295295
@test cy[4] SpecialFunctions.gamma.(x[4])
296296
end
297297
end
298+
299+
@testset "special functions: erf" begin
300+
import SpecialFunctions
301+
302+
@kernel function erf_knl(A, @Const(B))
303+
I = @index(Global)
304+
@inbounds A[I] = SpecialFunctions.erf(B[I])
305+
end
306+
307+
x = [-1.0,-0.5,0.0,1e-3,1.0,2.0,5.5]
308+
y = similar(x)
309+
event = erf_knl(CPU())(y, x; ndrange=length(x))
310+
wait(event)
311+
@test y == SpecialFunctions.erf.(x)
312+
313+
if has_cuda_gpu()
314+
cx = CuArray(x)
315+
cy = similar(cx)
316+
event = erf_knl(CUDADevice())(cy, cx; ndrange=length(x))
317+
wait(event)
318+
319+
cy = Array(cy)
320+
@test cy[1:3] == SpecialFunctions.erf.(x[1:3])
321+
@test cy[4] SpecialFunctions.erf.(x[4])
322+
end
323+
end
324+
325+
@testset "special functions: erfc" begin
326+
import SpecialFunctions
327+
328+
@kernel function erfc_knl(A, @Const(B))
329+
I = @index(Global)
330+
@inbounds A[I] = SpecialFunctions.erfc(B[I])
331+
end
332+
333+
x = [-1.0,-0.5,0.0,1e-3,1.0,2.0,5.5]
334+
y = similar(x)
335+
event = erfc_knl(CPU())(y, x; ndrange=length(x))
336+
wait(event)
337+
@test y == SpecialFunctions.erfc.(x)
338+
339+
if has_cuda_gpu()
340+
cx = CuArray(x)
341+
cy = similar(cx)
342+
event = erfc_knl(CUDADevice())(cy, cx; ndrange=length(x))
343+
wait(event)
344+
345+
cy = Array(cy)
346+
@test cy[1:3] == SpecialFunctions.erfc.(x[1:3])
347+
@test cy[4] SpecialFunctions.erfc.(x[4])
348+
end
349+
end

0 commit comments

Comments
 (0)