Skip to content

Commit e1c250a

Browse files
committed
Convert the quarter-round function into a macro
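Convert _QR! from an inlined function into a macro, to re-unify some of the CPU and GPU ChaCha code. The previous changes to explicitly vectorize operations with SIMD required us to define a second method of _QR! and dispatch on it: one method mutated elements of an array in place, while the other took four values and returned the updated values. Because a macro expands in place, its arguments can be any assignable expressions, either local variables (e.g. the SIMD vectors `v0`..`v3`) or indexing expressions (e.g. `slice[i]`). By converting _QR! into a macro, we can therefore use the same code to represent a quarter-round on either GPU arrays or CPU ones.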
1 parent d28f74f commit e1c250a

File tree

2 files changed

+17
-23
lines changed

2 files changed

+17
-23
lines changed

src/ChaCha.jl

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,15 @@ const CHACHA_BLOCK_SIZE = div(32 * 16, 8)
2424
:(shufflevector(x, $rotation))
2525
end
2626

27-
@inline function _QR!(x, a, b, c, d)
28-
@inbounds begin
29-
x[a] += x[b]; x[d] ⊻= x[a]; x[d] = lrot32(x[d], UInt32(16))
30-
x[c] += x[d]; x[b] ⊻= x[c]; x[b] = lrot32(x[b], UInt32(12))
31-
x[a] += x[b]; x[d] ⊻= x[a]; x[d] = lrot32(x[d], UInt32(8))
32-
x[c] += x[d]; x[b] ⊻= x[c]; x[b] = lrot32(x[b], UInt32(7))
33-
end
34-
end
27+
macro _QR!(a, b, c, d)
28+
quote
29+
$(esc(a)) += $(esc(b)); $(esc(d)) ⊻= $(esc(a)); $(esc(d)) = lrot32($(esc(d)), 16);
30+
$(esc(c)) += $(esc(d)); $(esc(b)) ⊻= $(esc(c)); $(esc(b)) = lrot32($(esc(b)), 12);
31+
$(esc(a)) += $(esc(b)); $(esc(d)) ⊻= $(esc(a)); $(esc(d)) = lrot32($(esc(d)), 8);
32+
$(esc(c)) += $(esc(d)); $(esc(b)) ⊻= $(esc(c)); $(esc(b)) = lrot32($(esc(b)), 7);
3533

36-
@inline function _QR!(a, b, c, d)
37-
a += b; d ⊻= a; d = lrot32(d, UInt32(16));
38-
c += d; b ⊻= c; b = lrot32(b, UInt32(12));
39-
a += b; d ⊻= a; d = lrot32(d, UInt32(8));
40-
c += d; b ⊻= c; b = lrot32(b, UInt32(7));
41-
a, b, c, d
34+
$(esc(a)), $(esc(b)), $(esc(c)), $(esc(d))
35+
end
4236
end
4337

4438
@inline function store_u64!(x::AbstractVector{UInt32}, u::UInt64, idx)
@@ -219,12 +213,12 @@ end
219213
v3 = vgather(state, $idx3)
220214

221215
for i = 1:doublerounds
222-
v0, v1, v2, v3 = _QR!(v0, v1, v2, v3)
216+
v0, v1, v2, v3 = @_QR!(v0, v1, v2, v3)
223217
v1 = rotatevector(v1, Val(-1))
224218
v2 = rotatevector(v2, Val(-2))
225219
v3 = rotatevector(v3, Val(-3))
226220

227-
v0, v1, v2, v3 = _QR!(v0, v1, v2, v3)
221+
v0, v1, v2, v3 = @_QR!(v0, v1, v2, v3)
228222
v1 = rotatevector(v1, Val(1))
229223
v2 = rotatevector(v2, Val(2))
230224
v3 = rotatevector(v3, Val(3))
@@ -264,7 +258,7 @@ function _cuda_chacha_rounds!(state, doublerounds)
264258

265259
# Only operate on a slice of the state corresponding to
266260
# the thread block
267-
state_slice = view(state, block+1:block+16)
261+
slice = view(state, block+1:block+16)
268262

269263
# Pre-compute the indices that this thread will use to
270264
# perform its diagonal rounds
@@ -279,11 +273,11 @@ function _cuda_chacha_rounds!(state, doublerounds)
279273
# Each thread in the same block runs its rounds in parallel
280274
for _ = 1:doublerounds
281275
# Columnar rounds
282-
_QR!(state_slice, i, i + 4, i + 8, i + 12)
276+
@_QR!(slice[i], slice[i+4], slice[i+8], slice[i+12])
283277
CUDA.threadfence_block()
284278

285279
# Diagonal rounds
286-
_QR!(state_slice, dgc1, dgc2, dgc3, dgc4)
280+
@_QR!(slice[dgc1], slice[dgc2], slice[dgc3], slice[dgc4])
287281
CUDA.threadfence_block()
288282
end
289283

test/test_chacha.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ end
157157
# Ref: IETF RFC 8439, Sec. 2.1.1
158158
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.1.1
159159
state = MVector{4,UInt32}([0x11111111, 0x01020304, 0x9b8d6f43, 0x01234567])
160-
ChaCha._QR!(state, 1, 2, 3, 4)
160+
ChaCha.@_QR!(state[1], state[2], state[3], state[4])
161161

162162
expected_state = SVector{4,UInt32}([0xea2a92f4, 0xcb1cf8ce, 0x4581472e, 0x5881c4bb])
163163

@@ -174,7 +174,7 @@ end
174174
])
175175
initial_state = deepcopy(state)
176176

177-
ChaCha._QR!(state, 3, 8, 9, 14)
177+
ChaCha.@_QR!(state[3], state[8], state[9], state[14])
178178

179179
mask = trues(length(state))
180180
mask[3] = mask[8] = mask[9] = mask[14] = false
@@ -224,11 +224,11 @@ end
224224

225225
function kernel(state, a, b, c, d)
226226
i = 4 * (threadIdx().x - 1)
227-
ChaCha._QR!(state, i + a, i + b, i + c, i + d)
227+
ChaCha.@_QR!(state[i+a], state[i+b], state[i+c], state[i+d])
228228
nothing
229229
end
230230

231-
ChaCha._QR!(state, 1, 2, 3, 4)
231+
ChaCha.@_QR!(state[1], state[2], state[3], state[4])
232232
CUDA.@sync @cuda threads=1024 kernel(state_gpu, 1, 2, 3, 4)
233233

234234
@test state_gpu == CuArray(collect(repeat(state, 1024)))

0 commit comments

Comments
 (0)