Skip to content

Commit 1b26071

Browse files
authored
Merge pull request #130 from jakebolewski/jcb/fix_julia_15
Rewrite expoenent implementation for the CUDA backend with Cassette for Julia 1.5 compatibility
2 parents b728db6 + 934ad2c commit 1b26071

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1414
[compat]
1515
Adapt = "0.4, 1.0, 2.0"
1616
CUDA = "~1.0, ~1.1, ~1.2, 1.3"
17-
Cassette = "0.3.2"
17+
Cassette = "0.3.3"
1818
MacroTools = "0.5"
1919
SpecialFunctions = "0.10"
2020
StaticArrays = "0.12"

src/backends/cuda.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=nothin
139139
ndrange = (ndrange,)
140140
end
141141
if workgroupsize isa Integer
142-
workgroupsize = (workgroupsize, )
142+
workgroupsize = (workgroupsize,)
143143
end
144144

145145
if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
@@ -256,6 +256,23 @@ end
256256
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = CUDA.erf(x)
257257
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = CUDA.erfc(x)
258258

259+
@inline function Cassette.overdub(::CUDACtx, ::typeof(exponent), x::Union{Float32, Float64})
260+
T = typeof(x)
261+
xs = reinterpret(Unsigned, x) & ~Base.sign_mask(T)
262+
if xs >= Base.exponent_mask(T)
263+
throw(DomainError(x, "Cannot be Nan of Inf."))
264+
end
265+
k = Int(xs >> Base.significand_bits(T))
266+
if k == 0 # x is subnormal
267+
if xs == 0
268+
throw(DomainError(x, "Cannot be subnormal converted to 0."))
269+
end
270+
m = Base.leading_zeros(xs) - Base.exponent_bits(T)
271+
k = 1 - m
272+
end
273+
return k - Base.exponent_bias(T)
274+
end
275+
259276
@static if Base.isbindingresolved(CUDA, :emit_shmem) && Base.isdefined(CUDA, :emit_shmem)
260277
const emit_shmem = CUDA.emit_shmem
261278
else

src/compiler/pass.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ function transform(ctx, ref)
4141
end
4242

4343
# overdubbing IntrinsicFunctions removes our ability to profile code
44+
# Base.eltype is special because it is used for type inference
4445
newstmt = (x, i) -> begin
4546
isassign = Base.Meta.isexpr(x, :(=))
4647
stmt = isassign ? x.args[2] : x
@@ -61,7 +62,7 @@ function transform(ctx, ref)
6162
name = f.name
6263
if Base.isbindingresolved(mod, name) && Base.isdefined(mod, name)
6364
ff = getfield(f.mod, f.name)
64-
if ff isa Core.IntrinsicFunction || ff isa Core.Builtin
65+
if ff isa Core.IntrinsicFunction || ff isa Core.Builtin || ff === Base.eltype
6566
stmt.args[fidx] = Expr(:nooverdub, f)
6667
end
6768
end
@@ -76,4 +77,4 @@ function transform(ctx, ref)
7677
return CI
7778
end
7879

79-
const CompilerPass = Cassette.@pass transform
80+
const CompilerPass = Cassette.@pass transform

0 commit comments

Comments
 (0)