@@ -5,43 +5,48 @@ for CSPRNG.
module ChaCha

-using Core.Intrinsics: llvmcall
using CUDA
+using SIMD
using StaticArrays

# ChaCha block size is 32 * 16 bits = 64 bytes
const CHACHA_BLOCK_SIZE_U32 = 16
const CHACHA_BLOCK_SIZE = div(32 * 16, 8)

@inline lrot32(x, n) = (x << n) | (x >> (32 - n))
-@inline lrot32(x::UInt32, n::UInt32) = llvmcall(
-    ("""
-    declare i32 @llvm.fshl.i32(i32, i32, i32)
-    define i32 @entry(i32, i32, i32) #0 {
-    3:
-        %res = call i32 @llvm.fshl.i32(i32 %0, i32 %0, i32 %1)
-        ret i32 %res
-    }
-    attributes #0 = { alwaysinline }
-    """, "entry"), UInt32, Tuple{UInt32, UInt32}, x, n)
-
-@inline function _QR!(x, a, b, c, d)
-    @inbounds begin
-        x[a] += x[b]; x[d] ⊻= x[a]; x[d] = lrot32(x[d], UInt32(16))
-        x[c] += x[d]; x[b] ⊻= x[c]; x[b] = lrot32(x[b], UInt32(12))
-        x[a] += x[b]; x[d] ⊻= x[a]; x[d] = lrot32(x[d], UInt32(8))
-        x[c] += x[d]; x[b] ⊻= x[c]; x[b] = lrot32(x[b], UInt32(7))
+@inline lrot32(x::Union{Vec,UInt32}, n) = bitrotate(x, n)
+
+@inline @generated function rotatevector(x::Vec{N,T}, ::Val{M}) where {N,T,M}
+    rotation = circshift(0:3, M)
+    rotation = repeat(rotation, N ÷ 4)
+    rotation += 4 * ((0:N-1) .÷ 4)
+    rotation = Val(Tuple(rotation))
+    :(shufflevector(x, $rotation))
+end
+
+macro _QR!(a, b, c, d)
+    quote
+        $(esc(a)) += $(esc(b)); $(esc(d)) ⊻= $(esc(a)); $(esc(d)) = lrot32($(esc(d)), 16);
+        $(esc(c)) += $(esc(d)); $(esc(b)) ⊻= $(esc(c)); $(esc(b)) = lrot32($(esc(b)), 12);
+        $(esc(a)) += $(esc(b)); $(esc(d)) ⊻= $(esc(a)); $(esc(d)) = lrot32($(esc(d)), 8);
+        $(esc(c)) += $(esc(d)); $(esc(b)) ⊻= $(esc(c)); $(esc(b)) = lrot32($(esc(b)), 7);
+
+        $(esc(a)), $(esc(b)), $(esc(c)), $(esc(d))
    end
end

@inline function store_u64!(x::AbstractVector{UInt32}, u::UInt64, idx)
-    x[idx] = UInt32(u & 0xffffffff)
-    x[idx+1] = UInt32((u >> 32) & 0xffffffff)
+    @inbounds begin
+        x[idx] = UInt32(u & 0xffffffff)
+        x[idx+1] = UInt32((u >> 32) & 0xffffffff)
+    end
end

@inline function add_u64!(x::AbstractVector{UInt32}, u::UInt64, idx)
-    x[idx] += UInt32(u & 0xffffffff)
-    x[idx+1] += UInt32((u >> 32) & 0xffffffff)
+    @inbounds begin
+        x[idx] += UInt32(u & 0xffffffff)
+        x[idx+1] += UInt32((u >> 32) & 0xffffffff)
+    end
end

#=
@@ -144,40 +149,89 @@ function chacha_blocks!(
    nblocks = 1;
    doublerounds = 10,
)
-    for i ∈ 1:nblocks
-        block_start = CHACHA_BLOCK_SIZE_U32 * (i - 1) + 1
-        block_end = block_start + CHACHA_BLOCK_SIZE_U32 - 1
-        state = view(buffer, block_start:block_end)
-
-        _chacha_set_initial_state!(state, key, nonce, counter, 1)
-
-        # Perform alternating rounds of columnar
-        # quarter-rounds and diagonal quarter-rounds
-        for i = 1:doublerounds
-            # Columnar rounds
-            _QR!(state, 1, 5, 9, 13)
-            _QR!(state, 2, 6, 10, 14)
-            _QR!(state, 3, 7, 11, 15)
-            _QR!(state, 4, 8, 12, 16)
-
-            # Diagonal rounds
-            _QR!(state, 1, 6, 11, 16)
-            _QR!(state, 2, 7, 12, 13)
-            _QR!(state, 3, 8, 9, 14)
-            _QR!(state, 4, 5, 10, 15)
-        end
-
-        # Finish by adding the initial state back to
-        # the original state, so that the operations
-        # are no longer invertible
-        _chacha_add_initial_state!(state, key, nonce, counter, 1)
+    block_start = 1
+
+    # We compute as many blocks of output as possible with 512-bit
+    # SIMD vectorization
+    for i ∈ 1:4:nblocks-3
+        block_start, counter = _chacha_blocks!(
+            buffer, block_start, key, nonce, counter, doublerounds, Val(4)
+        )
+    end

-        counter += 1
+    # The remaining blocks are computed with 128-bit vectorization
+    for i ∈ 1:(nblocks % 4)
+        block_start, counter = _chacha_blocks!(
+            buffer, block_start, key, nonce, counter, doublerounds, Val(1)
+        )
    end

    counter
end

+# Compute the ChaCha block function with N * 128-bit SIMD vectorization
+#
+# Reference: https://eprint.iacr.org/2013/759.pdf
+@inline function _chacha_blocks!(
+    buffer::AbstractVector{UInt32}, block_start, key, nonce, counter, doublerounds, ::Val{N}
+) where N
+    block_end = block_start + N * CHACHA_BLOCK_SIZE_U32 - 1
+    @inbounds state = view(buffer, block_start:block_end)
+
+    for i = 0:N-1
+        _chacha_set_initial_state!(state, key, nonce, counter + i, i * CHACHA_BLOCK_SIZE_U32 + 1)
+    end
+
+    _chacha_rounds!(state, doublerounds, Val(N))
+
+    for i = 0:N-1
+        _chacha_add_initial_state!(state, key, nonce, counter + i, i * CHACHA_BLOCK_SIZE_U32 + 1)
+    end
+
+    block_end + 1, counter + N
+end
+
+
+@inline @generated function _chacha_rounds!(state, doublerounds, ::Val{N}) where N
+    # Perform alternating rounds of columnar
+    # quarter-rounds and diagonal quarter-rounds
+    lane = (1, 2, 3, 4)
+    lane = repeat(1:4, N)
+    lane += 16 * ((0:4*N-1) .÷ 4)
+    lane = Tuple(lane)
+
+    idx0 = Vec(lane)
+    idx1 = Vec(lane .+ 4)
+    idx2 = Vec(lane .+ 8)
+    idx3 = Vec(lane .+ 12)
+
+    quote
+        @inbounds begin
+            v0 = vgather(state, $idx0)
+            v1 = vgather(state, $idx1)
+            v2 = vgather(state, $idx2)
+            v3 = vgather(state, $idx3)
+
+            for i = 1:doublerounds
+                v0, v1, v2, v3 = @_QR!(v0, v1, v2, v3)
+                v1 = rotatevector(v1, Val(-1))
+                v2 = rotatevector(v2, Val(-2))
+                v3 = rotatevector(v3, Val(-3))
+
+                v0, v1, v2, v3 = @_QR!(v0, v1, v2, v3)
+                v1 = rotatevector(v1, Val(1))
+                v2 = rotatevector(v2, Val(2))
+                v3 = rotatevector(v3, Val(3))
+            end
+
+            vscatter(v0, state, $idx0)
+            vscatter(v1, state, $idx1)
+            vscatter(v2, state, $idx2)
+            vscatter(v3, state, $idx3)
+        end
+    end
+end
+
function chacha_blocks!(
    buffer::CuArray, key, nonce::UInt64, counter::UInt64, nblocks = 1; doublerounds = 10
)
@@ -204,7 +258,7 @@ function _cuda_chacha_rounds!(state, doublerounds)

    # Only operate on a slice of the state corresponding to
    # the thread block
-    state_slice = view(state, block+1:block+16)
+    slice = view(state, block+1:block+16)

    # Pre-compute the indices that this thread will use to
    # perform its diagonal rounds
@@ -219,11 +273,11 @@ function _cuda_chacha_rounds!(state, doublerounds)
    # Each thread in the same block runs its rounds in parallel
    for _ = 1:doublerounds
        # Columnar rounds
-        _QR!(state_slice, i, i + 4, i + 8, i + 12)
+        @_QR!(slice[i], slice[i + 4], slice[i + 8], slice[i + 12])
        CUDA.threadfence_block()

        # Diagonal rounds
-        _QR!(state_slice, dgc1, dgc2, dgc3, dgc4)
+        @_QR!(slice[dgc1], slice[dgc2], slice[dgc3], slice[dgc4])
        CUDA.threadfence_block()
    end

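Usage note (not part of the commit): a minimal sketch of how the vectorized CPU path might be driven. It assumes the CPU method of chacha_blocks! mirrors the CuArray method's signature shown above, and that the key is accepted as eight UInt32 words (a 256-bit key); both details depend on _chacha_set_initial_state!, which is outside this diff.

    using ChaCha: chacha_blocks!, CHACHA_BLOCK_SIZE_U32

    key = rand(UInt32, 8)        # hypothetical 256-bit key as eight 32-bit words
    nonce = UInt64(0)
    counter = UInt64(0)
    nblocks = 6                  # four blocks take the Val(4) SIMD path, two take Val(1)
    buffer = zeros(UInt32, nblocks * CHACHA_BLOCK_SIZE_U32)

    # Fills buffer with nblocks * 64 bytes of keystream and returns the advanced counter
    counter = chacha_blocks!(buffer, key, nonce, counter, nblocks)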