Skip to content

Commit eb4ec03

Browse files
committed
Fix bad sqrt override
1 parent d602193 commit eb4ec03

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

src/device/intrinsics/math.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,13 +396,7 @@ end
396396

397397
@device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x)
398398
@device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x)
399-
@device_override function Base.sqrt(x::Float16)
400-
if compute_capability() >= sv"8.0"
401-
ccall("extern __nv_hsqrt", llvmcall, Float16, (Float16,), x)
402-
else
403-
return Float16(sqrt(Float32(x)))
404-
end
405-
end
399+
@device_override function Base.sqrt(x::Float16) = Float16(sqrt(Float32(x)))
406400
@device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x)
407401

408402
@device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x)

0 commit comments

Comments
 (0)