Skip to content

Commit 2f2f79e

Browse files
authored
Merge pull request #4 from kernelmethod/kernelmethod/simd
Improve SIMD parallelization
2 parents a362c8a + e1c250a commit 2f2f79e

File tree

5 files changed

+157
-65
lines changed

5 files changed

+157
-65
lines changed

Manifest.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,11 @@ version = "1.3.0"
250250
[[deps.SHA]]
251251
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
252252

253+
[[deps.SIMD]]
254+
git-tree-sha1 = "7dbc15af7ed5f751a82bf3ed37757adf76c32402"
255+
uuid = "fdea26ae-647d-5447-a871-4b548cad5224"
256+
version = "3.4.1"
257+
253258
[[deps.Serialization]]
254259
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
255260

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ version = "0.1.0"
66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
88
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
910
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1011

1112
[compat]
1213
CUDA = "3.8"
14+
SIMD = "3.4"
1315
StaticArrays = "1.4"
1416
julia = "1"

src/ChaCha.jl

Lines changed: 107 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,48 @@ for CSPRNG.
55

66
module ChaCha
77

8-
using Core.Intrinsics: llvmcall
98
using CUDA
9+
using SIMD
1010
using StaticArrays
1111

1212
# ChaCha block size is 32 * 16 bits = 64 bytes
1313
const CHACHA_BLOCK_SIZE_U32 = 16
1414
const CHACHA_BLOCK_SIZE = div(32 * 16, 8)
1515

1616
@inline lrot32(x, n) = (x << n) | (x >> (32 - n))
17-
@inline lrot32(x::UInt32, n::UInt32) = llvmcall(
18-
("""
19-
declare i32 @llvm.fshl.i32(i32, i32, i32)
20-
define i32 @entry(i32, i32, i32) #0 {
21-
3:
22-
%res = call i32 @llvm.fshl.i32(i32 %0, i32 %0, i32 %1)
23-
ret i32 %res
24-
}
25-
attributes #0 = { alwaysinline }
26-
""", "entry"), UInt32, Tuple{UInt32, UInt32}, x, n)
27-
28-
@inline function _QR!(x, a, b, c, d)
29-
@inbounds begin
30-
x[a] += x[b]; x[d] ⊻= x[a]; x[d] = lrot32(x[d], UInt32(16))
31-
x[c] += x[d]; x[b] ⊻= x[c]; x[b] = lrot32(x[b], UInt32(12))
32-
x[a] += x[b]; x[d] ⊻= x[a]; x[d] = lrot32(x[d], UInt32(8))
33-
x[c] += x[d]; x[b] ⊻= x[c]; x[b] = lrot32(x[b], UInt32(7))
17+
@inline lrot32(x::Union{Vec,UInt32}, n) = bitrotate(x, n)
18+
19+
@inline @generated function rotatevector(x::Vec{N,T}, ::Val{M}) where {N,T,M}
20+
rotation = circshift(0:3, M)
21+
rotation = repeat(rotation, N ÷ 4)
22+
rotation += 4 * ((0:N-1) .÷ 4)
23+
rotation = Val(Tuple(rotation))
24+
:(shufflevector(x, $rotation))
25+
end
26+
27+
macro _QR!(a, b, c, d)
28+
quote
29+
$(esc(a)) += $(esc(b)); $(esc(d)) ⊻= $(esc(a)); $(esc(d)) = lrot32($(esc(d)), 16);
30+
$(esc(c)) += $(esc(d)); $(esc(b)) ⊻= $(esc(c)); $(esc(b)) = lrot32($(esc(b)), 12);
31+
$(esc(a)) += $(esc(b)); $(esc(d)) ⊻= $(esc(a)); $(esc(d)) = lrot32($(esc(d)), 8);
32+
$(esc(c)) += $(esc(d)); $(esc(b)) ⊻= $(esc(c)); $(esc(b)) = lrot32($(esc(b)), 7);
33+
34+
$(esc(a)), $(esc(b)), $(esc(c)), $(esc(d))
3435
end
3536
end
3637

3738
@inline function store_u64!(x::AbstractVector{UInt32}, u::UInt64, idx)
38-
x[idx] = UInt32(u & 0xffffffff)
39-
x[idx+1] = UInt32((u >> 32) & 0xffffffff)
39+
@inbounds begin
40+
x[idx] = UInt32(u & 0xffffffff)
41+
x[idx+1] = UInt32((u >> 32) & 0xffffffff)
42+
end
4043
end
4144

4245
@inline function add_u64!(x::AbstractVector{UInt32}, u::UInt64, idx)
43-
x[idx] += UInt32(u & 0xffffffff)
44-
x[idx+1] += UInt32((u >> 32) & 0xffffffff)
46+
@inbounds begin
47+
x[idx] += UInt32(u & 0xffffffff)
48+
x[idx+1] += UInt32((u >> 32) & 0xffffffff)
49+
end
4550
end
4651

4752
#=
@@ -144,40 +149,89 @@ function chacha_blocks!(
144149
nblocks = 1;
145150
doublerounds = 10,
146151
)
147-
for i ∈ 1:nblocks
148-
block_start = CHACHA_BLOCK_SIZE_U32 * (i - 1) + 1
149-
block_end = block_start + CHACHA_BLOCK_SIZE_U32 - 1
150-
state = view(buffer, block_start:block_end)
151-
152-
_chacha_set_initial_state!(state, key, nonce, counter, 1)
153-
154-
# Perform alternating rounds of columnar
155-
# quarter-rounds and diagonal quarter-rounds
156-
for i = 1:doublerounds
157-
# Columnar rounds
158-
_QR!(state, 1, 5, 9, 13)
159-
_QR!(state, 2, 6, 10, 14)
160-
_QR!(state, 3, 7, 11, 15)
161-
_QR!(state, 4, 8, 12, 16)
162-
163-
# Diagonal rounds
164-
_QR!(state, 1, 6, 11, 16)
165-
_QR!(state, 2, 7, 12, 13)
166-
_QR!(state, 3, 8, 9, 14)
167-
_QR!(state, 4, 5, 10, 15)
168-
end
169-
170-
# Finish by adding the initial state back to
171-
# the original state, so that the operations
172-
# are no longer invertible
173-
_chacha_add_initial_state!(state, key, nonce, counter, 1)
152+
block_start = 1
153+
154+
# We compute as many blocks of output as possible with 512-bit
155+
# SIMD vectorization
156+
for i ∈ 1:4:nblocks-3
157+
block_start, counter = _chacha_blocks!(
158+
buffer, block_start, key, nonce, counter, doublerounds, Val(4)
159+
)
160+
end
174161

175-
counter += 1
162+
# The remaining blocks are computed with 128-bit vectorization
163+
for i ∈ 1:(nblocks % 4)
164+
block_start, counter = _chacha_blocks!(
165+
buffer, block_start, key, nonce, counter, doublerounds, Val(1)
166+
)
176167
end
177168

178169
counter
179170
end
180171

172+
# Compute the ChaCha block function with N * 128-bit SIMD vectorization
173+
#
174+
# Reference: https://eprint.iacr.org/2013/759.pdf
175+
@inline function _chacha_blocks!(
176+
buffer::AbstractVector{UInt32}, block_start, key, nonce, counter, doublerounds, ::Val{N}
177+
) where N
178+
block_end = block_start + N * CHACHA_BLOCK_SIZE_U32 - 1
179+
@inbounds state = view(buffer, block_start:block_end)
180+
181+
for i = 0:N-1
182+
_chacha_set_initial_state!(state, key, nonce, counter + i, i * CHACHA_BLOCK_SIZE_U32 + 1)
183+
end
184+
185+
_chacha_rounds!(state, doublerounds, Val(N))
186+
187+
for i = 0:N-1
188+
_chacha_add_initial_state!(state, key, nonce, counter + i, i * CHACHA_BLOCK_SIZE_U32 + 1)
189+
end
190+
191+
block_end + 1, counter + N
192+
end
193+
194+
195+
@inline @generated function _chacha_rounds!(state, doublerounds, ::Val{N}) where N
196+
# Perform alternating rounds of columnar
197+
# quarter-rounds and diagonal quarter-rounds
198+
lane = (1, 2, 3, 4)
199+
lane = repeat(1:4, N)
200+
lane += 16 * ((0:4*N-1) .÷ 4)
201+
lane = Tuple(lane)
202+
203+
idx0 = Vec(lane)
204+
idx1 = Vec(lane .+ 4)
205+
idx2 = Vec(lane .+ 8)
206+
idx3 = Vec(lane .+ 12)
207+
208+
quote
209+
@inbounds begin
210+
v0 = vgather(state, $idx0)
211+
v1 = vgather(state, $idx1)
212+
v2 = vgather(state, $idx2)
213+
v3 = vgather(state, $idx3)
214+
215+
for i = 1:doublerounds
216+
v0, v1, v2, v3 = @_QR!(v0, v1, v2, v3)
217+
v1 = rotatevector(v1, Val(-1))
218+
v2 = rotatevector(v2, Val(-2))
219+
v3 = rotatevector(v3, Val(-3))
220+
221+
v0, v1, v2, v3 = @_QR!(v0, v1, v2, v3)
222+
v1 = rotatevector(v1, Val(1))
223+
v2 = rotatevector(v2, Val(2))
224+
v3 = rotatevector(v3, Val(3))
225+
end
226+
227+
vscatter(v0, state, $idx0)
228+
vscatter(v1, state, $idx1)
229+
vscatter(v2, state, $idx2)
230+
vscatter(v3, state, $idx3)
231+
end
232+
end
233+
end
234+
181235
function chacha_blocks!(
182236
buffer::CuArray, key, nonce::UInt64, counter::UInt64, nblocks = 1; doublerounds = 10
183237
)
@@ -204,7 +258,7 @@ function _cuda_chacha_rounds!(state, doublerounds)
204258

205259
# Only operate on a slice of the state corresponding to
206260
# the thread block
207-
state_slice = view(state, block+1:block+16)
261+
slice = view(state, block+1:block+16)
208262

209263
# Pre-compute the indices that this thread will use to
210264
# perform its diagonal rounds
@@ -219,11 +273,11 @@ function _cuda_chacha_rounds!(state, doublerounds)
219273
# Each thread in the same block runs its rounds in parallel
220274
for _ = 1:doublerounds
221275
# Columnar rounds
222-
_QR!(state_slice, i, i + 4, i + 8, i + 12)
276+
@_QR!(slice[i], slice[i+4], slice[i+8], slice[i+12])
223277
CUDA.threadfence_block()
224278

225279
# Diagonal rounds
226-
_QR!(state_slice, dgc1, dgc2, dgc3, dgc4)
280+
@_QR!(slice[dgc1], slice[dgc2], slice[dgc3], slice[dgc4])
227281
CUDA.threadfence_block()
228282
end
229283

src/keystream.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,17 +122,12 @@ end
122122
function _fill_blocks!(
123123
buffer::AbstractVector{T}, stream::ChaChaStream, nblocks::Int
124124
) where {T <: BitInteger}
125-
bufsize_u32 = div(length(buffer) * sizeof(T), sizeof(UInt32))
125+
bufsize_u32 = sizeof(buffer) ÷ sizeof(UInt32)
126126

127127
GC.@preserve buffer begin
128-
# Create a pointer to the start of the block,
129-
# and wrap it in an instance of UnsafeView.
130-
#
131-
# This provides a decent speedup over using
132-
# reinterpret(UInt32, ...)
133128
bp = pointer(buffer)
134129
bp = Base.unsafe_convert(Ptr{UInt32}, bp)
135-
bufview = UnsafeView(bp, bufsize_u32)
130+
bufview = unsafe_wrap(Vector{UInt32}, bp, bufsize_u32)
136131

137132
stream.counter = chacha_blocks!(
138133
bufview,

test/test_chacha.jl

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using StaticArrays
1313
using Test
1414

1515
function chacha_blocks_test_suite(T)
16-
@testset "Test chacha_blocks!" begin
16+
@testset "RFC 8439 ChaCha block function tests" begin
1717
# Ref: IETF RFC 8439, Sec. 2.3.2
1818
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.3.2
1919
key = SVector{8,UInt32}([
@@ -114,14 +114,50 @@ function chacha_blocks_test_suite(T)
114114
@test state == test_vector
115115
end
116116

117+
@testset "Extended ChaCha block function tests" begin
118+
# Run multiple blocks of ChaCha with key, counter, and nonce equal
119+
# to zero
120+
#
121+
# It's more efficient to compute multiple blocks in parallel on both
122+
# CPU and GPU, so this test ensures that parallelization doesn't
123+
# introduce any new errors.
124+
key = SVector{8,UInt32}(zeros(UInt32, 8)) |> T
125+
nonce = UInt64(0)
126+
counter = UInt64(0)
127+
test_vector = SVector{64,UInt32}([
128+
# Block 1
129+
0xade0b876, 0x903df1a0, 0xe56a5d40, 0x28bd8653,
130+
0xb819d2bd, 0x1aed8da0, 0xccef36a8, 0xc70d778b,
131+
0x7c5941da, 0x8d485751, 0x3fe02477, 0x374ad8b8,
132+
0xf4b8436a, 0x1ca11815, 0x69b687c3, 0x8665eeb2,
133+
# Block 2
134+
0xbee7079f, 0x7a385155, 0x7c97ba98, 0x0d082d73,
135+
0xa0290fcb, 0x6965e348, 0x3e53c612, 0xed7aee32,
136+
0x7621b729, 0x434ee69c, 0xb03371d5, 0xd539d874,
137+
0x281fed31, 0x45fb0a51, 0x1f0ae1ac, 0x6f4d794b,
138+
# Block 3
139+
0xe6a0092d, 0xe16c2663, 0x08d17eae, 0x75a06819,
140+
0x998e718e, 0xc662d37b, 0x3446c3b0, 0x5db3a0a9,
141+
0x68372701, 0x0f5d7b1f, 0xfd3a1e28, 0x1ebc58e4,
142+
0x13d3d273, 0xc094cfc9, 0x6271f35f, 0xf248a240,
143+
# Block 4
144+
0x58a02013, 0x6b56b3d7, 0xaada20d5, 0x0abfd23e,
145+
0x20b1b8c5, 0x732785fb, 0x349763c3, 0xa4915cb4,
146+
0x83cbd42d, 0x2e0d84f8, 0x1358b1ed, 0x3fac6210,
147+
0xfff82c1f, 0x5618cd6d, 0x6c1e6ae8, 0x7e166731
148+
]) |> T
149+
state = MVector{64,UInt32}(undef) |> T
150+
@test ChaCha.chacha_blocks!(state, key, nonce, counter, 4) == counter + 4
151+
@test state == test_vector
152+
end
117153
end
118154

119155
@testset "ChaCha tests" begin
120156
@testset "Quarter-round function tests" begin
121157
# Ref: IETF RFC 8439, Sec. 2.1.1
122158
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.1.1
123159
state = MVector{4,UInt32}([0x11111111, 0x01020304, 0x9b8d6f43, 0x01234567])
124-
ChaCha._QR!(state, 1, 2, 3, 4)
160+
ChaCha.@_QR!(state[1], state[2], state[3], state[4])
125161

126162
expected_state = SVector{4,UInt32}([0xea2a92f4, 0xcb1cf8ce, 0x4581472e, 0x5881c4bb])
127163

@@ -138,7 +174,7 @@ end
138174
])
139175
initial_state = deepcopy(state)
140176

141-
ChaCha._QR!(state, 3, 8, 9, 14)
177+
ChaCha.@_QR!(state[3], state[8], state[9], state[14])
142178

143179
mask = trues(length(state))
144180
mask[3] = mask[8] = mask[9] = mask[14] = false
@@ -188,11 +224,11 @@ end
188224

189225
function kernel(state, a, b, c, d)
190226
i = 4 * (threadIdx().x - 1)
191-
ChaCha._QR!(state, i + a, i + b, i + c, i + d)
227+
ChaCha.@_QR!(state[i+a], state[i+b], state[i+c], state[i+d])
192228
nothing
193229
end
194230

195-
ChaCha._QR!(state, 1, 2, 3, 4)
231+
ChaCha.@_QR!(state[1], state[2], state[3], state[4])
196232
CUDA.@sync @cuda threads=1024 kernel(state_gpu, 1, 2, 3, 4)
197233

198234
@test state_gpu == CuArray(collect(repeat(state, 1024)))

0 commit comments

Comments
 (0)