Skip to content

Commit d28f74f

Browse files
committed
Improve SIMD vectorization of the ChaCha block function on CPU
Using some techniques borrowed from [1], I've made some improvements to the instruction-level parallelism of the ChaCha block function for CPU. The block function now uses AVX-512 instructions to compute as many ChaCha blocks as possible, and then uses 128-bit vectorization to compute the remaining blocks. In my tests this gives a pretty big speedup for CPU-based random number generation. [1] https://eprint.iacr.org/2013/759.pdf
1 parent a362c8a commit d28f74f

File tree

5 files changed

+149
-51
lines changed

5 files changed

+149
-51
lines changed

Manifest.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,11 @@ version = "1.3.0"
250250
[[deps.SHA]]
251251
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
252252

253+
[[deps.SIMD]]
254+
git-tree-sha1 = "7dbc15af7ed5f751a82bf3ed37757adf76c32402"
255+
uuid = "fdea26ae-647d-5447-a871-4b548cad5224"
256+
version = "3.4.1"
257+
253258
[[deps.Serialization]]
254259
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
255260

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ version = "0.1.0"
66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
88
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
910
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1011

1112
[compat]
1213
CUDA = "3.8"
14+
SIMD = "3.4"
1315
StaticArrays = "1.4"
1416
julia = "1"

src/ChaCha.jl

Lines changed: 103 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,24 @@ for CSPRNG.
55

66
module ChaCha
77

8-
using Core.Intrinsics: llvmcall
98
using CUDA
9+
using SIMD
1010
using StaticArrays
1111

1212
# ChaCha block size is 32 * 16 bits = 64 bytes
1313
const CHACHA_BLOCK_SIZE_U32 = 16
1414
const CHACHA_BLOCK_SIZE = div(32 * 16, 8)
1515

1616
@inline lrot32(x, n) = (x << n) | (x >> (32 - n))
17-
@inline lrot32(x::UInt32, n::UInt32) = llvmcall(
18-
("""
19-
declare i32 @llvm.fshl.i32(i32, i32, i32)
20-
define i32 @entry(i32, i32, i32) #0 {
21-
3:
22-
%res = call i32 @llvm.fshl.i32(i32 %0, i32 %0, i32 %1)
23-
ret i32 %res
24-
}
25-
attributes #0 = { alwaysinline }
26-
""", "entry"), UInt32, Tuple{UInt32, UInt32}, x, n)
17+
@inline lrot32(x::Union{Vec,UInt32}, n) = bitrotate(x, n)
18+
19+
@inline @generated function rotatevector(x::Vec{N,T}, ::Val{M}) where {N,T,M}
20+
rotation = circshift(0:3, M)
21+
rotation = repeat(rotation, N ÷ 4)
22+
rotation += 4 * ((0:N-1) 4)
23+
rotation = Val(Tuple(rotation))
24+
:(shufflevector(x, $rotation))
25+
end
2726

2827
@inline function _QR!(x, a, b, c, d)
2928
@inbounds begin
@@ -34,14 +33,26 @@ const CHACHA_BLOCK_SIZE = div(32 * 16, 8)
3433
end
3534
end
3635

36+
@inline function _QR!(a, b, c, d)
37+
a += b; d ⊻= a; d = lrot32(d, UInt32(16));
38+
c += d; b ⊻= c; b = lrot32(b, UInt32(12));
39+
a += b; d ⊻= a; d = lrot32(d, UInt32(8));
40+
c += d; b ⊻= c; b = lrot32(b, UInt32(7));
41+
a, b, c, d
42+
end
43+
3744
@inline function store_u64!(x::AbstractVector{UInt32}, u::UInt64, idx)
38-
x[idx] = UInt32(u & 0xffffffff)
39-
x[idx+1] = UInt32((u >> 32) & 0xffffffff)
45+
@inbounds begin
46+
x[idx] = UInt32(u & 0xffffffff)
47+
x[idx+1] = UInt32((u >> 32) & 0xffffffff)
48+
end
4049
end
4150

4251
@inline function add_u64!(x::AbstractVector{UInt32}, u::UInt64, idx)
43-
x[idx] += UInt32(u & 0xffffffff)
44-
x[idx+1] += UInt32((u >> 32) & 0xffffffff)
52+
@inbounds begin
53+
x[idx] += UInt32(u & 0xffffffff)
54+
x[idx+1] += UInt32((u >> 32) & 0xffffffff)
55+
end
4556
end
4657

4758
#=
@@ -144,40 +155,89 @@ function chacha_blocks!(
144155
nblocks = 1;
145156
doublerounds = 10,
146157
)
147-
for i 1:nblocks
148-
block_start = CHACHA_BLOCK_SIZE_U32 * (i - 1) + 1
149-
block_end = block_start + CHACHA_BLOCK_SIZE_U32 - 1
150-
state = view(buffer, block_start:block_end)
151-
152-
_chacha_set_initial_state!(state, key, nonce, counter, 1)
153-
154-
# Perform alternating rounds of columnar
155-
# quarter-rounds and diagonal quarter-rounds
156-
for i = 1:doublerounds
157-
# Columnar rounds
158-
_QR!(state, 1, 5, 9, 13)
159-
_QR!(state, 2, 6, 10, 14)
160-
_QR!(state, 3, 7, 11, 15)
161-
_QR!(state, 4, 8, 12, 16)
162-
163-
# Diagonal rounds
164-
_QR!(state, 1, 6, 11, 16)
165-
_QR!(state, 2, 7, 12, 13)
166-
_QR!(state, 3, 8, 9, 14)
167-
_QR!(state, 4, 5, 10, 15)
168-
end
169-
170-
# Finish by adding the initial state back to
171-
# the original state, so that the operations
172-
# are no longer invertible
173-
_chacha_add_initial_state!(state, key, nonce, counter, 1)
158+
block_start = 1
159+
160+
# We compute as many blocks of output as possible with 512-bit
161+
# SIMD vectorization
162+
for i 1:4:nblocks-3
163+
block_start, counter = _chacha_blocks!(
164+
buffer, block_start, key, nonce, counter, doublerounds, Val(4)
165+
)
166+
end
174167

175-
counter += 1
168+
# The remaining blocks are computed with 128-bit vectorization
169+
for i 1:(nblocks % 4)
170+
block_start, counter = _chacha_blocks!(
171+
buffer, block_start, key, nonce, counter, doublerounds, Val(1)
172+
)
176173
end
177174

178175
counter
179176
end
180177

178+
# Compute the ChaCha block function with N * 128-bit SIMD vectorization
179+
#
180+
# Reference: https://eprint.iacr.org/2013/759.pdf
181+
@inline function _chacha_blocks!(
182+
buffer::AbstractVector{UInt32}, block_start, key, nonce, counter, doublerounds, ::Val{N}
183+
) where N
184+
block_end = block_start + N * CHACHA_BLOCK_SIZE_U32 - 1
185+
@inbounds state = view(buffer, block_start:block_end)
186+
187+
for i = 0:N-1
188+
_chacha_set_initial_state!(state, key, nonce, counter + i, i * CHACHA_BLOCK_SIZE_U32 + 1)
189+
end
190+
191+
_chacha_rounds!(state, doublerounds, Val(N))
192+
193+
for i = 0:N-1
194+
_chacha_add_initial_state!(state, key, nonce, counter + i, i * CHACHA_BLOCK_SIZE_U32 + 1)
195+
end
196+
197+
block_end + 1, counter + N
198+
end
199+
200+
201+
@inline @generated function _chacha_rounds!(state, doublerounds, ::Val{N}) where N
202+
# Perform alternating rounds of columnar
203+
# quarter-rounds and diagonal quarter-rounds
204+
lane = (1, 2, 3, 4)
205+
lane = repeat(1:4, N)
206+
lane += 16 * ((0:4*N-1) 4)
207+
lane = Tuple(lane)
208+
209+
idx0 = Vec(lane)
210+
idx1 = Vec(lane .+ 4)
211+
idx2 = Vec(lane .+ 8)
212+
idx3 = Vec(lane .+ 12)
213+
214+
quote
215+
@inbounds begin
216+
v0 = vgather(state, $idx0)
217+
v1 = vgather(state, $idx1)
218+
v2 = vgather(state, $idx2)
219+
v3 = vgather(state, $idx3)
220+
221+
for i = 1:doublerounds
222+
v0, v1, v2, v3 = _QR!(v0, v1, v2, v3)
223+
v1 = rotatevector(v1, Val(-1))
224+
v2 = rotatevector(v2, Val(-2))
225+
v3 = rotatevector(v3, Val(-3))
226+
227+
v0, v1, v2, v3 = _QR!(v0, v1, v2, v3)
228+
v1 = rotatevector(v1, Val(1))
229+
v2 = rotatevector(v2, Val(2))
230+
v3 = rotatevector(v3, Val(3))
231+
end
232+
233+
vscatter(v0, state, $idx0)
234+
vscatter(v1, state, $idx1)
235+
vscatter(v2, state, $idx2)
236+
vscatter(v3, state, $idx3)
237+
end
238+
end
239+
end
240+
181241
function chacha_blocks!(
182242
buffer::CuArray, key, nonce::UInt64, counter::UInt64, nblocks = 1; doublerounds = 10
183243
)

src/keystream.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,17 +122,12 @@ end
122122
function _fill_blocks!(
123123
buffer::AbstractVector{T}, stream::ChaChaStream, nblocks::Int
124124
) where {T <: BitInteger}
125-
bufsize_u32 = div(length(buffer) * sizeof(T), sizeof(UInt32))
125+
bufsize_u32 = sizeof(buffer) ÷ sizeof(UInt32)
126126

127127
GC.@preserve buffer begin
128-
# Create a pointer to the start of the block,
129-
# and wrap it in an instance of UnsafeView.
130-
#
131-
# This provides a decent speedup over using
132-
# reinterpret(UInt32, ...)
133128
bp = pointer(buffer)
134129
bp = Base.unsafe_convert(Ptr{UInt32}, bp)
135-
bufview = UnsafeView(bp, bufsize_u32)
130+
bufview = unsafe_wrap(Vector{UInt32}, bp, bufsize_u32)
136131

137132
stream.counter = chacha_blocks!(
138133
bufview,

test/test_chacha.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using StaticArrays
1313
using Test
1414

1515
function chacha_blocks_test_suite(T)
16-
@testset "Test chacha_blocks!" begin
16+
@testset "RFC 8439 ChaCha block function tests" begin
1717
# Ref: IETF RFC 8439, Sec. 2.3.2
1818
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.3.2
1919
key = SVector{8,UInt32}([
@@ -114,6 +114,42 @@ function chacha_blocks_test_suite(T)
114114
@test state == test_vector
115115
end
116116

117+
@testset "Extended ChaCha block function tests" begin
118+
# Run multiple blocks of ChaCha with key, counter, and nonce equal
119+
# to zero
120+
#
121+
# It's more efficient to compute multiple blocks in parallel on both
122+
# CPU and GPU, so this test ensures that parallelization doesn't
123+
# introduce any new errors.
124+
key = SVector{8,UInt32}(zeros(UInt32, 8)) |> T
125+
nonce = UInt64(0)
126+
counter = UInt64(0)
127+
test_vector = SVector{64,UInt32}([
128+
# Block 1
129+
0xade0b876, 0x903df1a0, 0xe56a5d40, 0x28bd8653,
130+
0xb819d2bd, 0x1aed8da0, 0xccef36a8, 0xc70d778b,
131+
0x7c5941da, 0x8d485751, 0x3fe02477, 0x374ad8b8,
132+
0xf4b8436a, 0x1ca11815, 0x69b687c3, 0x8665eeb2,
133+
# Block 2
134+
0xbee7079f, 0x7a385155, 0x7c97ba98, 0x0d082d73,
135+
0xa0290fcb, 0x6965e348, 0x3e53c612, 0xed7aee32,
136+
0x7621b729, 0x434ee69c, 0xb03371d5, 0xd539d874,
137+
0x281fed31, 0x45fb0a51, 0x1f0ae1ac, 0x6f4d794b,
138+
# Block 3
139+
0xe6a0092d, 0xe16c2663, 0x08d17eae, 0x75a06819,
140+
0x998e718e, 0xc662d37b, 0x3446c3b0, 0x5db3a0a9,
141+
0x68372701, 0x0f5d7b1f, 0xfd3a1e28, 0x1ebc58e4,
142+
0x13d3d273, 0xc094cfc9, 0x6271f35f, 0xf248a240,
143+
# Block 4
144+
0x58a02013, 0x6b56b3d7, 0xaada20d5, 0x0abfd23e,
145+
0x20b1b8c5, 0x732785fb, 0x349763c3, 0xa4915cb4,
146+
0x83cbd42d, 0x2e0d84f8, 0x1358b1ed, 0x3fac6210,
147+
0xfff82c1f, 0x5618cd6d, 0x6c1e6ae8, 0x7e166731
148+
]) |> T
149+
state = MVector{64,UInt32}(undef) |> T
150+
@test ChaCha.chacha_blocks!(state, key, nonce, counter, 4) == counter + 4
151+
@test state == test_vector
152+
end
117153
end
118154

119155
@testset "ChaCha tests" begin

0 commit comments

Comments
 (0)