Skip to content

Commit 0a515fc

Browse files
committed
Native RNG: Support large arrays by avoiding counter overflowing.
1 parent 54bf5b0 commit 0a515fc

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/random.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ function Random.rand!(rng::RNG, A::AnyCuArray)
5252

5353
# grid-stride loop
5454
threadId = threadIdx().x
55-
window = blockDim().x * gridDim().x
56-
offset = (blockIdx().x - 1) * blockDim().x
55+
window = widemul(blockDim().x, gridDim().x)
56+
offset = widemul(blockIdx().x - 1i32, blockDim().x)
5757
while offset < length(A)
5858
i = threadId + offset
5959
if i <= length(A)
@@ -96,8 +96,8 @@ function Random.randn!(rng::RNG, A::AnyCuArray{<:Union{AbstractFloat,Complex{<:A
9696

9797
# grid-stride loop
9898
threadId = threadIdx().x
99-
window = (blockDim().x - 1) * gridDim().x
100-
offset = (blockIdx().x - 1) * blockDim().x
99+
window = widemul(blockDim().x - 1i32, gridDim().x)
100+
offset = widemul(blockIdx().x - 1i32, blockDim().x)
101101
while offset < length(A)
102102
i = threadId + offset
103103
j = threadId + offset + window
@@ -129,8 +129,8 @@ function Random.randn!(rng::RNG, A::AnyCuArray{<:Union{AbstractFloat,Complex{<:A
129129

130130
# grid-stride loop
131131
threadId = threadIdx().x
132-
window = (blockDim().x - 1) * gridDim().x
133-
offset = (blockIdx().x - 1) * blockDim().x
132+
window = widemul(blockDim().x - 1i32, gridDim().x)
133+
offset = widemul(blockIdx().x - 1i32, blockDim().x)
134134
while offset < length(A)
135135
i = threadId + offset
136136
if i <= length(A)

test/base/random.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,9 @@ end
198198
end
199199
end
200200

201+
@testset "counter overflow" begin
202+
rng = CUDA.RNG()
203+
c = CUDA.zeros(Float16, (64, 32, 512, 32, 64))
204+
rand!(rng, c)
205+
randn!(rng, c)
206+
end

0 commit comments

Comments
 (0)