Skip to content

Commit b543f19

Browse files
committed
Limit dispatch of Random.rand! for ChaChaStream and CUDAChaChaStream
Change type signature of Random.rand! implementations for ChaChaStream and CUDAChaChaStream so that they don't accept an arbitrary AbstractVector, only Vector. Both of these methods actually require a Vector for the way that they're implemented (technically they could accept other inputs, but there isn't a well-defined interface for what those inputs should do). With this change, errors from incorrect calls to Random.rand! should be a little clearer.
1 parent 707c5ac commit b543f19

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

src/generation.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Random.rand!(rng::AbstractChaChaStream, A::CuArray{<:BitInteger}) =
3030

3131
Random.rand!(
3232
rng::AbstractChaChaStream,
33-
A::AbstractArray{T},
33+
A::Array{T},
3434
::Type{T},
3535
) where {T <: BitInteger} =
3636
(Random.rand!(rng, vec(A), T); A)
@@ -46,12 +46,7 @@ Random.rand(rng::ChaChaStream, T::Random.SamplerType{<:BitInteger}) =
4646
_fetch_one!(rng, T[])
4747

4848
# Inplace operations
49-
function Random.rand!(
50-
rng::ChaChaStream,
51-
A::AbstractVector{T},
52-
::Type{T}
53-
) where {T <: BitInteger}
54-
49+
function Random.rand!(rng::ChaChaStream, A::Vector{T}, ::Type{T}) where {T <: BitInteger}
5550
# Reinterpret the array as a byte array
5651
@GC.preserve A begin
5752
p = pointer(A)
@@ -68,11 +63,7 @@ end
6863
RNG methods for CUDAChaChaStream
6964
=#
7065

71-
function Random.rand!(
72-
rng::CUDAChaChaStream,
73-
A::AbstractVector{T},
74-
::Type{T}
75-
) where {T <: BitInteger}
66+
function Random.rand!(rng::CUDAChaChaStream, A::Vector{T}, ::Type{T}) where {T <: BitInteger}
7667
# Perform sampling on GPU and then copy to CPU
7768
A_gpu = CuVector{T}(undef, length(A))
7869
Random.rand!(rng, A_gpu, T)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
33
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
44
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
55
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
67
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
78
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/test_keystream.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
using ChaChaCiphers
44
using Random
5+
using SparseArrays
56
using Statistics
67
using Test
78

@@ -96,6 +97,21 @@ using Test
9697
@test randn(stream, 1500) == randn(stream_copy, 1500)
9798
end
9899

100+
@testset "Generate random sparse array" begin
101+
rng = ChaCha12Stream()
102+
x = sprand(rng, Int32, 400, 400, 0.1)
103+
104+
@test x isa AbstractMatrix{Int32}
105+
@test x isa SparseMatrixCSC{Int32}
106+
@test isapprox(mean(x .== 0), 0.9, atol=1e-2)
107+
108+
Random.rand!(rng, x)
109+
110+
@test x isa AbstractMatrix{Int32}
111+
@test x isa SparseMatrixCSC{Int32}
112+
@test (x .!= 0) |> all
113+
end
114+
99115
@testset "Encrypt data with a keystream" begin
100116
# Ref: IETF RFC 8439, Sec. A.2
101117
# https://datatracker.ietf.org/doc/html/rfc8439#appendix-A.2

0 commit comments

Comments
 (0)