Skip to content

Commit 76eacff

Browse files
authored
Avoid ReshapedArray using Int128 in kernel (#449)
1 parent b06f0fe commit 76eacff

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

src/device/opencl/math.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,25 @@ end
182182

183183

184184
# TODO: half and native
185+
186+
function _mulhi(a::Int64, b::Int64)
187+
shift = sizeof(a) * 4
188+
mask = typemax(UInt32)
189+
a1, a2 = (a >> shift), a & mask
190+
b1, b2 = (b >> shift), b & mask
191+
a1b1, a1b2, a2b1 = a1*b1, a1*b2, a2*b1
192+
t1 = a1b2 + _mulhi(a2 % UInt32, b2 % UInt32)
193+
t2 = a2b1 + (t1 & mask)
194+
a1b1 + (t1 >> shift) + (t2 >> shift)
195+
end
196+
@static if isdefined(Base.MultiplicativeInverses, :_mul_high)
197+
_mulhi(a::T, b::T) where {T<:Union{Signed, Unsigned}} = Base.MultiplicativeInverses._mul_high(a, b)
198+
@device_override Base.MultiplicativeInverses._mul_high(a::Int64, b::Int64) = _mulhi(a, b)
199+
else
200+
_mulhi(a::T, b::T) where {T<:Union{Signed, Unsigned}} = ((widen(a)*b) >>> (sizeof(a)*8)) % T
201+
@device_override function Base.div(a::Int64, b::Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64})
202+
x = _mulhi(a, b.multiplier)
203+
x += (a*b.addmul) % Int64
204+
ifelse(abs(b.divisor) == 1, a*b.divisor, (signbit(x) + (x >> b.shift)) % Int64)
205+
end
206+
end

0 commit comments

Comments
 (0)