Skip to content

Commit 6481ae2

Browse files
authored
Allow use of CPU RNG without scalar iteration. (#378)
1 parent c09f370 commit 6481ae2

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

src/host/random.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,9 @@ function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
109109
end
110110
A
111111
end
112+
113+
# allow use of CPU RNGs without scalar iteration
114+
Random.rand!(rng::AbstractRNG, A::AnyGPUArray) =
115+
copyto!(A, rand(rng, eltype(A), size(A)...))
116+
Random.randn!(rng::AbstractRNG, A::AnyGPUArray) =
117+
copyto!(A, randn(rng, eltype(A), size(A)...))

test/testsuite/random.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
else
55
Random.default_rng()
66
end
7+
cpu_rng = Random.default_rng()
78

89
@testset "rand" begin # uniform
910
for T in eltypes, d in (10, (10,10))
@@ -19,6 +20,10 @@
1920
Random.seed!(rng, 1)
2021
rand!(rng, B)
2122
@test all(Array(A) .== Array(B))
23+
24+
if rng != cpu_rng
25+
rand!(cpu_rng, A)
26+
end
2227
end
2328

2429
A = AT{Bool}(undef, 1024)
@@ -46,6 +51,10 @@
4651
Random.seed!(rng, 1)
4752
randn!(rng, B)
4853
@test all(Array(A) .== Array(B))
54+
55+
if rng != cpu_rng
56+
randn!(cpu_rng, A)
57+
end
4958
end
5059
end
5160
end

0 commit comments

Comments
 (0)