Skip to content

Commit c420cf0

Browse files
committed
Add the CUDAChaChaStream type
Add a new type, CUDAChaChaStream, for accessing the keystream of a ChaCha cipher on the GPU. This type can be used for GPU-based CRNG and encryption.
1 parent 926d794 commit c420cf0

File tree

7 files changed

+333
-10
lines changed

7 files changed

+333
-10
lines changed

src/ChaChaCiphers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ include("ChaCha.jl")
44

55
include("core.jl")
66
include("keystream.jl")
7+
include("cuda_keystream.jl")
78
include("generation.jl")
89

910
export ChaCha
1011
export ChaChaStream, ChaCha20Stream, ChaCha12Stream
12+
export CUDAChaChaStream, CUDAChaCha20Stream, CUDAChaCha12Stream
1113
export getstate
1214
export encrypt, decrypt, encrypt!, decrypt!
1315

src/cuda_keystream.jl

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
using Base: BitInteger
using ChaChaCiphers.ChaCha
using CUDA
using Random
using StaticArrays
5+
6+
"""
7+
CUDAChaChaStream <: AbstractChaChaStream
8+
9+
`CUDAChaChaStream` is a CUDA-compatible ChaCha keystream
10+
generator for GPU CRNG.
11+
12+
## Examples
13+
14+
Create a `CUDAChaChaStream` with a randomized key, and
15+
sample some random numbers with it:
16+
17+
```@meta
18+
DocTestSetup = quote
19+
using CUDA
20+
using ChaChaCiphers
21+
using Random
22+
end
23+
```
24+
25+
```julia
26+
julia> rng = CUDAChaChaStream();
27+
28+
julia> x = CuVector{Float32}(undef, 2^10);
29+
```
30+
31+
```@meta
32+
DocTestSetup = nothing
33+
```
34+
35+
"""
36+
mutable struct CUDAChaChaStream <: AbstractChaChaStream
    # Cipher key, held on the GPU as `UInt32` words
    key          :: CuVector{UInt32}
    nonce        :: UInt64
    # Block counter; overwritten with the value returned by `chacha_blocks!`
    # every time keystream blocks are generated
    counter      :: UInt64
    # GPU-resident buffer of pre-generated keystream bytes
    buffer       :: CuVector{UInt8}
    # 1-based index of the next unread byte in `buffer`
    position     :: Int
    # Number of double rounds (the ChaCha20 alias uses 10, ChaCha12 uses 6)
    doublerounds :: Int

    # Build a stream from an explicit key and nonce.
    #
    # `key` is converted to a `CuVector{UInt32}`. `counter` is the block
    # counter used for the initial buffer fill, and `position` is the
    # 1-based read offset into that freshly generated buffer.
    #
    # Throws an `ErrorException` when `doublerounds` is not positive.
    function CUDAChaChaStream(
        key,
        nonce,
        counter = UInt64(0),
        position = 1;
        doublerounds = 10
    )
        if doublerounds <= 0
            error("`doublerounds` must be a positive number")
        end

        key = CuVector{UInt32}(key)
        buffer = CuVector{UInt8}(undef, STREAM_BUFFER_SIZE)
        stream = new(key, nonce, counter, buffer, 1, doublerounds)

        # Fill the buffer first (this resets `position` to 1), then apply
        # the caller-supplied read offset
        _refresh_buffer!(stream)
        stream.position = position

        return stream
    end

    # Build a stream with a freshly generated key.
    #
    # The eight key words are drawn from the operating system's entropy
    # source via `RandomDevice`, and the nonce is set to zero. This is the
    # form used by `CUDAChaChaStream()`, `CUDAChaCha20Stream()` and
    # `CUDAChaCha12Stream()`.
    function CUDAChaChaStream(; doublerounds = 10)
        key = rand(RandomDevice(), UInt32, 8)
        return CUDAChaChaStream(key, UInt64(0); doublerounds = doublerounds)
    end
end
64+
65+
# Constructors
66+
67+
"""
68+
CUDAChaCha20Stream
69+
70+
Create a CUDA-compatible keystream for a ChaCha20 stream
71+
cipher.
72+
"""
73+
function CUDAChaCha20Stream(args...)
    # ChaCha20 = 20 rounds = 10 double rounds; positional arguments are
    # forwarded untouched to the general constructor
    return CUDAChaChaStream(args...; doublerounds = 10)
end
74+
75+
"""
76+
CUDAChaCha12Stream
77+
78+
Create a CUDA-compatible keystream for a ChaCha12 stream
79+
cipher.
80+
"""
81+
function CUDAChaCha12Stream(args...)
    # ChaCha12 = 12 rounds = 6 double rounds; positional arguments are
    # forwarded untouched to the general constructor
    return CUDAChaChaStream(args...; doublerounds = 6)
end
82+
83+
# Print a multi-line summary of the stream: key, nonce, counter and the
# total number of rounds (twice the stored number of double rounds).
#
# Reading the key goes through `key(rng)`, which copies it from the GPU
# into host memory.
function Base.show(io::IO, rng::CUDAChaChaStream)
    msg = """
    CUDAChaChaStream(
        key = $(key(rng)),
        nonce = $(nonce(rng)),
        counter = $(counter(rng)),
        rounds = $(2 * doublerounds(rng))
    )"""

    print(io, msg)
    return nothing
end
94+
95+
# Number of unread keystream bytes left in the stream's internal buffer,
# counting the byte at `stream.position` itself.
function buffer_size(stream::CUDAChaChaStream)
    unread = length(stream.buffer) - stream.position + 1
    return unread
end
97+
98+
# Methods required for AbstractChaChaStream compatibility
99+
100+
# Return the stream's key as an `SVector{8,UInt32}` in host memory.
# The key lives in a `CuVector`, so it is first copied into a host vector.
function key(stream::CUDAChaChaStream)
    host_words = copyto!(Vector{UInt32}(undef, 8), stream.key)
    return SVector{8,UInt32}(host_words)
end
105+
106+
# Plain field accessors; each returns the corresponding field unchanged.
@inline function nonce(stream::CUDAChaChaStream)
    return stream.nonce
end

@inline function counter(stream::CUDAChaChaStream)
    return stream.counter
end

@inline function position(stream::CUDAChaChaStream)
    return stream.position
end

@inline function doublerounds(stream::CUDAChaChaStream)
    return stream.doublerounds
end
110+
111+
# Regenerate the whole internal buffer with `STREAM_BUFFER_BLOCKS` new
# keystream blocks and rewind the read position to the first byte.
# Returns the stream itself.
function _refresh_buffer!(stream::CUDAChaChaStream)
    _fill_blocks!(stream.buffer, stream, STREAM_BUFFER_BLOCKS)
    stream.position = 1
    return stream
end
120+
121+
# Fill `dest` with keystream bytes taken from `stream`, and return `dest`.
#
# Bytes still unread in the internal buffer are consumed first. If that is
# not enough, whole blocks are generated straight into `dest`, the internal
# buffer is regenerated, and the final partial block is served from it.
function _fill_buffer!(dest::CuVector{UInt8}, stream::CUDAChaChaStream)
    available = buffer_size(stream)
    wanted = length(dest)

    if available < wanted
        # Drain whatever is left of the internal buffer into the front of
        # `dest`. The read position is not advanced here because the buffer
        # is regenerated (and the position reset) further down.
        copyto!(dest, 1, stream.buffer, stream.position, available)

        (whole_blocks, leftover) = divrem(wanted - available, CHACHA_BLOCK_SIZE)
        if whole_blocks > 0
            # Wrap the block-aligned middle part of `dest` so that it can be
            # written directly, without passing through the internal buffer
            start = pointer(dest, available + 1)
            window = unsafe_wrap(
                CuVector{UInt8}, start, whole_blocks * CHACHA_BLOCK_SIZE
            )
            _fill_blocks!(window, stream, whole_blocks)
        end

        # Regenerate the internal buffer, then serve the trailing `leftover`
        # bytes of `dest` from it through the short path below
        _refresh_buffer!(stream)
        _fill_buffer!(view(dest, wanted-leftover+1:wanted), stream)
    else
        # The internal buffer already holds enough bytes: copy and advance
        copyto!(dest, 1, stream.buffer, stream.position, wanted)
        stream.position += wanted
    end

    return dest
end
151+
152+
# Generate `nblocks` keystream blocks directly into `buffer` and return
# `buffer`. The stream's counter is replaced by the value that
# `chacha_blocks!` returns.
#
# `buffer` is viewed as `UInt32` words through its raw device pointer; no
# check is made here that it is large enough to hold `nblocks` blocks.
#
# NOTE(review): `stream.doublerounds` is not forwarded to `chacha_blocks!`
# in this call. Confirm whether that function accepts a round count, and
# if so whether ChaCha12 streams are actually generated with 12 rounds.
function _fill_blocks!(
    buffer::CuVector{T}, stream::CUDAChaChaStream, nblocks::Int
) where {T <: BitInteger}
    words_ptr = Base.unsafe_convert(CuPtr{UInt32}, pointer(buffer))
    words = unsafe_wrap(
        CuVector{UInt32}, words_ptr, nblocks * CHACHA_BLOCK_SIZE_U32
    )

    stream.counter = chacha_blocks!(
        words, stream.key, stream.nonce, stream.counter, nblocks
    )

    return buffer
end
170+

src/generation.jl

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,54 @@ Implementation of the RNG API for ChaChaStream
44
55
=#
66

7+
using CUDA
78
using Random
89
using Random: SamplerType
910

10-
Random.rng_native_52(::ChaChaStream) = UInt64
11+
#=
12+
RNG methods for all subtypes of AbstractChaChaStream
13+
=#
14+
15+
# Every ChaCha stream reports `UInt64` as its native type for the
# 52-bit sampling path of the `Random` API.
function Random.rng_native_52(::AbstractChaChaStream)
    return UInt64
end
16+
17+
# Support for different dimension specifications
18+
# Scalar draw: sample a one-element array and unwrap its only entry
function Random.rand(rng::AbstractChaChaStream, T::Type{<:BitInteger})
    return Random.rand(rng, T, 1)[]
end

# Dimensions given as separate integers: pack them into a `Dims` tuple
function Random.rand(
    rng::AbstractChaChaStream, T::Type{<:BitInteger}, dim1::Int, dims::Int...
)
    shape = Dims((dim1, dims...))
    return Random.rand(rng, T, shape)
end

# Dimensions given as a tuple: allocate a host array and fill it in place
function Random.rand(rng::AbstractChaChaStream, T::Type{<:BitInteger}, dims::Dims)
    out = Array{T}(undef, dims)
    return Random.rand!(rng, out)
end
24+
25+
# Inplace operations
26+
# Two-argument form: the element type to sample is the array's own eltype
function Random.rand!(rng::AbstractChaChaStream, A::AbstractArray{<:BitInteger})
    return Random.rand!(rng, A, eltype(A))
end

# Same forwarding for GPU arrays.
# NOTE(review): presumably declared separately so that this method wins over
# `rand!` methods that CUDA defines for `CuArray` — confirm.
function Random.rand!(rng::AbstractChaChaStream, A::CuArray{<:BitInteger})
    return Random.rand!(rng, A, eltype(A))
end
30+
31+
# Generic in-place fill: flatten `A`, delegate to the vector method of the
# concrete stream type, and hand `A` back. This relies on `vec(A)` sharing
# storage with `A`, so that filling the flattened form fills `A` itself.
function Random.rand!(
    rng::AbstractChaChaStream,
    A::AbstractArray{T},
    ::Type{T},
) where {T <: BitInteger}
    flat = vec(A)
    Random.rand!(rng, flat, T)
    return A
end
37+
38+
#=
39+
RNG methods for ChaChaStream
40+
=#
1141

1242
# Draw a single value of type `T` from the CPU stream
function Random.rand(rng::ChaChaStream, ::Type{T}) where {T <: BitInteger}
    return _fetch_one!(rng, T)
end

# Sampler form: `T[]` extracts the element type carried by the sampler
function Random.rand(rng::ChaChaStream, T::Random.SamplerType{<:BitInteger})
    return _fetch_one!(rng, T[])
end
1747

18-
Random.rand!(rng::ChaChaStream, A::Array) = Random.rand!(rng, A, eltype(A))
48+
# Inplace operations
49+
function Random.rand!(
50+
rng::ChaChaStream,
51+
A::AbstractVector{T},
52+
::Type{T}
53+
) where {T <: BitInteger}
1954

20-
function Random.rand!(rng::ChaChaStream, A::Array{T}, ::SamplerType{T}) where {T <: BitInteger}
21-
Random.rand!(rng, vec(A), T)
22-
A
23-
end
24-
25-
function Random.rand!(rng::ChaChaStream, A::Vector{T}, ::SamplerType{T}) where {T <: BitInteger}
2655
# Reinterpret the array as a byte array
2756
@GC.preserve A begin
2857
p = pointer(A)
@@ -35,4 +64,25 @@ function Random.rand!(rng::ChaChaStream, A::Vector{T}, ::SamplerType{T}) where {
3564
A
3665
end
3766

67+
#=
68+
RNG methods for CUDAChaChaStream
69+
=#
70+
71+
# Fill a (host) vector from a GPU stream: sample into a temporary GPU
# vector of the same length, then copy the result into `A`.
# Returns whatever `copyto!` returns, i.e. the destination `A`.
function Random.rand!(
    rng::CUDAChaChaStream,
    A::AbstractVector{T},
    ::Type{T}
) where {T <: BitInteger}
    staging = CuVector{T}(undef, length(A))
    Random.rand!(rng, staging, T)
    return copyto!(A, staging)
end
81+
82+
# Fill a GPU vector in place: reinterpret it as raw bytes and let the
# stream write keystream bytes straight into it. Returns `A`.
function Random.rand!(
    rng::CUDAChaChaStream,
    A::CuVector{T},
    ::Type{T}
) where {T <: BitInteger}
    bytes = reinterpret(UInt8, A)
    _fill_buffer!(bytes, rng)
    return A
end
3888

src/keystream.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ julia> randn(stream)
4545
julia> randstring(stream, 'a':'z', 8)
4646
"klmptewr"
4747
```
48+
49+
```@meta
50+
DocTestSetup = nothing
51+
```
4852
"""
4953
mutable struct ChaChaStream <: AbstractChaChaStream
5054
key :: SVector{8,UInt32}

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ChaChaCiphers, Documenter, Test
33
include("test_core.jl")
44
include("test_chacha.jl")
55
include("test_keystream.jl")
6+
include("test_cuda_keystream.jl")
67

78
@testset "Package doctests" begin
89
doctest(ChaChaCiphers; manual=false)

test/test_cuda_keystream.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Tests for random number generation with
2+
# CUDAChaChaStream
3+
4+
using ChaChaCiphers
5+
using CUDA
6+
using Random
7+
using Statistics
8+
using Test
9+
10+
# All tests need a working GPU; when CUDA is not functional the whole
# group is skipped with a warning instead of failing.
@testset "CUDAChaChaStream tests" begin
    if CUDA.functional()
        @testset "Construct CUDAChaChaStream" begin
            rng = CUDAChaCha12Stream()
            @test rng.doublerounds == 6

            rng = CUDAChaCha20Stream()
            @test rng.doublerounds == 10
        end

        @testset "Generate random strings" begin
            rng = CUDAChaChaStream(zeros(UInt32, 8), UInt64(0))
            x = randstring(rng, 'a':'c', 3 * 2^16)
            @test isa(x, String)
            @test length(x) == 3 * 2^16

            # Turn per-character counts into relative frequencies; each of
            # the three characters should show up about a third of the time
            counts = Dict((c => count(u -> u == c, x)) for c in 'a':'c')
            counts = Dict((c => counts[c] / length(x)) for c in 'a':'c')
            @test isapprox(counts['a'], 1/3, atol=1e-2)
            @test isapprox(counts['b'], 1/3, atol=1e-2)
            @test isapprox(counts['c'], 1/3, atol=1e-2)
        end

        @testset "Sample uniform random numbers" begin
            rng = CUDAChaChaStream(zeros(UInt32, 8), UInt64(0))
            x = rand(rng, Float32, 100_000)

            @test isa(x, Vector{Float32})
            @test size(x) == (100_000,)
            @test isapprox(mean(x), 0.5, atol=1e-2)

            x = rand(rng, Float64, 400, 300)
            @test isa(x, Array{Float64,2})
            @test size(x) == (400, 300)
            @test isapprox(mean(x), 0.5, atol=1e-2)

            # Generate random values directly inside of a pre-allocated array
            x_cpu = Vector{Float32}(undef, 100_000)
            Random.rand!(rng, x_cpu)
            @test isapprox(mean(x_cpu), 0.5, atol=1e-2)

            # Same, but sampling into a GPU array; zero the host array first
            # so that stale values cannot make the check pass
            x_cpu .= 0
            x_gpu = CuVector{Float32}(undef, 100_000)
            CUDA.@sync begin
                Random.rand!(rng, x_gpu)
                copyto!(x_cpu, x_gpu)
            end
            @test isapprox(mean(x_cpu), 0.5, atol=1e-2)
        end

        @testset "Sample random normal numbers" begin
            rng = CUDAChaChaStream(zeros(UInt32, 8), UInt64(0))
            x = randn(rng, Float32, 100_000)

            @test isa(x, Vector{Float32})
            @test size(x) == (100_000,)
            @test isapprox(mean(x), 0, atol=1e-2)
            @test isapprox(std(x), 1, atol=1e-2)

            x = randn(rng, Float64, 100, 50, 50)
            @test isa(x, Array{Float64,3})
            @test size(x) == (100, 50, 50)
            @test isapprox(mean(x), 0, atol=1e-2)
            @test isapprox(std(x), 1, atol=1e-2)
        end

        @testset "Save and restore keystream" begin
            # We should be able to save and restore a GPU
            # keystream as a CPU keystream, and vice-versa
            rng = CUDAChaChaStream(zeros(UInt32, 8), UInt64(0))
            state = getstate(rng)
            rng_cpu = ChaChaStream(state)

            rand_gpu = rand(rng, 1_000)
            rand_cpu = rand(rng_cpu, 1_000)

            @test rand_gpu == rand_cpu
        end
    else
        @warn "CUDA.functional() = false; skipping tests"
    end
end

0 commit comments

Comments
 (0)