Skip to content

Commit a362c8a

Browse files
committed
Run the RFC 8439 test suite for ChaCha on both CPU and GPU
1 parent 27dad48 commit a362c8a

File tree

1 file changed

+104
-84
lines changed

1 file changed

+104
-84
lines changed

test/test_chacha.jl

Lines changed: 104 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -12,71 +12,14 @@ using CUDA
1212
using StaticArrays
1313
using Test
1414

15-
@testset "ChaCha tests" begin
16-
@testset "Quarter-round function tests" begin
17-
# Ref: IETF RFC 8439, Sec. 2.1.1
18-
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.1.1
19-
state = MVector{4,UInt32}([0x11111111, 0x01020304, 0x9b8d6f43, 0x01234567])
20-
ChaCha._QR!(state, 1, 2, 3, 4)
21-
22-
expected_state = SVector{4,UInt32}([0xea2a92f4, 0xcb1cf8ce, 0x4581472e, 0x5881c4bb])
23-
24-
@test state == expected_state
25-
26-
# Ref: IETF RFC 8439, Sec. 2.2.1
27-
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.2.1
28-
29-
state = MVector{16,UInt32}([
30-
0x879531e0, 0xc5ecf37d, 0x516461b1, 0xc9a62f8a,
31-
0x44c20ef3, 0x3390af7f, 0xd9fc690b, 0x2a5f714c,
32-
0x53372767, 0xb00a5631, 0x974c541a, 0x359e9963,
33-
0x5c971061, 0x3d631689, 0x2098d9d6, 0x91dbd320,
34-
])
35-
initial_state = deepcopy(state)
36-
37-
ChaCha._QR!(state, 3, 8, 9, 14)
38-
39-
mask = trues(length(state))
40-
mask[3] = mask[8] = mask[9] = mask[14] = false
41-
42-
@test state[mask] == initial_state[mask]
43-
@test state[@.(~mask)] == [
44-
0xbdb886dc, 0xcfacafd2, 0xe46bea80, 0xccc07c79
45-
]
46-
end
47-
48-
@testset "Test initial state" begin
49-
# Ref: IETF RFC 8439, Sec. 2.2.1
50-
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.2.1
51-
key = SVector{8,UInt32}([
52-
0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
53-
0x13121110, 0x17161514, 0x1b1a1918, 0x1f1e1d1c,
54-
])
55-
nonce = 0x000000004a000000
56-
counter = 0x0900000000000001
57-
58-
state_set = MVector{16,UInt32}(undef)
59-
state_add = MVector{16,UInt32}(zeros(UInt32, 16))
60-
ChaCha._chacha_set_initial_state!(state_set, key, nonce, counter)
61-
ChaCha._chacha_add_initial_state!(state_add, key, nonce, counter)
62-
test_vector = MVector{16,UInt32}([
63-
0x61707865, 0x3320646e, 0x79622d32, 0x6b206574,
64-
0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
65-
0x13121110, 0x17161514, 0x1b1a1918, 0x1f1e1d1c,
66-
0x00000001, 0x09000000, 0x4a000000, 0x00000000,
67-
])
68-
69-
@test state_set == test_vector
70-
@test state_add == test_vector
71-
end
72-
15+
function chacha_blocks_test_suite(T)
7316
@testset "Test chacha_blocks!" begin
7417
# Ref: IETF RFC 8439, Sec. 2.3.2
7518
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.3.2
7619
key = SVector{8,UInt32}([
7720
0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
7821
0x13121110, 0x17161514, 0x1b1a1918, 0x1f1e1d1c,
79-
])
22+
]) |> T
8023
nonce = 0x000000004a000000
8124
counter = 0x0900000000000001
8225

@@ -85,8 +28,8 @@ using Test
8528
0xc7f4d1c7, 0x0368c033, 0x9aaa2204, 0x4e6cd4c3,
8629
0x466482d2, 0x09aa9f07, 0x05d7c214, 0xa2028bd9,
8730
0xd19c12b5, 0xb94e16de, 0xe883d0cb, 0x4e3c50a2,
88-
])
89-
state = MVector{16,UInt32}(undef)
31+
]) |> T
32+
state = MVector{16,UInt32}(undef) |> T
9033
@test ChaCha.chacha_blocks!(state, key, nonce, counter) == counter + 1
9134
@test state == test_vector
9235

@@ -95,30 +38,30 @@ using Test
9538

9639
# Test Vector #1:
9740
# ==============
98-
key = SVector{8,UInt32}(zeros(UInt32, 8))
41+
key = SVector{8,UInt32}(zeros(UInt32, 8)) |> T
9942
nonce = UInt64(0)
10043
counter = UInt64(0)
10144
test_vector = SVector{16,UInt32}([
10245
0xade0b876, 0x903df1a0, 0xe56a5d40, 0x28bd8653,
10346
0xb819d2bd, 0x1aed8da0, 0xccef36a8, 0xc70d778b,
10447
0x7c5941da, 0x8d485751, 0x3fe02477, 0x374ad8b8,
10548
0xf4b8436a, 0x1ca11815, 0x69b687c3, 0x8665eeb2
106-
])
107-
state = MVector{16,UInt32}(undef)
49+
]) |> T
50+
state = MVector{16,UInt32}(undef) |> T
10851
@test ChaCha.chacha_blocks!(state, key, nonce, counter) == counter + 1
10952
@test state == test_vector
11053

11154
# Test Vector #2:
11255
# ==============
113-
key = SVector{8,UInt32}(zeros(UInt32, 8))
56+
key = SVector{8,UInt32}(zeros(UInt32, 8)) |> T
11457
nonce = UInt64(0)
11558
counter = UInt64(1)
11659
test_vector = SVector{16,UInt32}([
11760
0xbee7079f, 0x7a385155, 0x7c97ba98, 0x0d082d73,
11861
0xa0290fcb, 0x6965e348, 0x3e53c612, 0xed7aee32,
11962
0x7621b729, 0x434ee69c, 0xb03371d5, 0xd539d874,
12063
0x281fed31, 0x45fb0a51, 0x1f0ae1ac, 0x6f4d794b
121-
])
64+
]) |> T
12265
@test ChaCha.chacha_blocks!(state, key, nonce, counter) == counter + 1
12366
@test state == test_vector
12467

@@ -127,13 +70,13 @@ using Test
12770
key = SVector{8,UInt32}([
12871
0x00000000, 0x00000000, 0x00000000, 0x00000000,
12972
0x00000000, 0x00000000, 0x00000000, 0x01000000,
130-
])
73+
]) |> T
13174
test_vector = SVector{16,UInt32}([
13275
0x2452eb3a, 0x9249f8ec, 0x8d829d9b, 0xddd4ceb1,
13376
0xe8252083, 0x60818b01, 0xf38422b8, 0x5aaa49c9,
13477
0xbb00ca8e, 0xda3ba7b4, 0xc4b592d1, 0xfdf2732f,
13578
0x4436274e, 0x2561b3c8, 0xebdd4aa6, 0xa0136c00
136-
])
79+
]) |> T
13780
nonce = UInt64(0)
13881
counter = UInt64(1)
13982
@test ChaCha.chacha_blocks!(state, key, nonce, counter) == counter + 1
@@ -144,15 +87,15 @@ using Test
14487
key = SVector{8,UInt32}([
14588
0x0000ff00, 0x00000000, 0x00000000, 0x00000000,
14689
0x00000000, 0x00000000, 0x00000000, 0x00000000,
147-
])
90+
]) |> T
14891
nonce = UInt64(0)
14992
counter = UInt64(2)
15093
test_vector = SVector{16,UInt32}([
15194
0xfb4dd572, 0x4bc42ef1, 0xdf922636, 0x327f1394,
15295
0xa78dea8f, 0x5e269039, 0xa1bebbc1, 0xcaf09aae,
15396
0xa25ab213, 0x48a6b46c, 0x1b9d9bcb, 0x092c5be6,
15497
0x546ca624, 0x1bec45d5, 0x87f47473, 0x96f0992e
155-
])
98+
]) |> T
15699
@test ChaCha.chacha_blocks!(state, key, nonce, counter) == counter + 1
157100
@test state == test_vector
158101

@@ -166,10 +109,72 @@ using Test
166109
0x88228b1a, 0x96a4dfb3, 0x5b76ab72, 0xc727ee54,
167110
0x0e0e978a, 0xf3145c95, 0x1b748ea8, 0xf786c297,
168111
0x99c28f5f, 0x628314e8, 0x398a19fa, 0x6ded1b53
169-
])
112+
]) |> T
170113
@test ChaCha.chacha_blocks!(state, key, nonce, counter) == counter + 1
171114
@test state == test_vector
172115
end
116+
117+
end
118+
119+
@testset "ChaCha tests" begin
120+
@testset "Quarter-round function tests" begin
121+
# Ref: IETF RFC 8439, Sec. 2.1.1
122+
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.1.1
123+
state = MVector{4,UInt32}([0x11111111, 0x01020304, 0x9b8d6f43, 0x01234567])
124+
ChaCha._QR!(state, 1, 2, 3, 4)
125+
126+
expected_state = SVector{4,UInt32}([0xea2a92f4, 0xcb1cf8ce, 0x4581472e, 0x5881c4bb])
127+
128+
@test state == expected_state
129+
130+
# Ref: IETF RFC 8439, Sec. 2.2.1
131+
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.2.1
132+
133+
state = MVector{16,UInt32}([
134+
0x879531e0, 0xc5ecf37d, 0x516461b1, 0xc9a62f8a,
135+
0x44c20ef3, 0x3390af7f, 0xd9fc690b, 0x2a5f714c,
136+
0x53372767, 0xb00a5631, 0x974c541a, 0x359e9963,
137+
0x5c971061, 0x3d631689, 0x2098d9d6, 0x91dbd320,
138+
])
139+
initial_state = deepcopy(state)
140+
141+
ChaCha._QR!(state, 3, 8, 9, 14)
142+
143+
mask = trues(length(state))
144+
mask[3] = mask[8] = mask[9] = mask[14] = false
145+
146+
@test state[mask] == initial_state[mask]
147+
@test state[@.(~mask)] == [
148+
0xbdb886dc, 0xcfacafd2, 0xe46bea80, 0xccc07c79
149+
]
150+
end
151+
152+
@testset "Test initial state" begin
153+
# Ref: IETF RFC 8439, Sec. 2.2.1
154+
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.2.1
155+
key = SVector{8,UInt32}([
156+
0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
157+
0x13121110, 0x17161514, 0x1b1a1918, 0x1f1e1d1c,
158+
])
159+
nonce = 0x000000004a000000
160+
counter = 0x0900000000000001
161+
162+
state_set = MVector{16,UInt32}(undef)
163+
state_add = MVector{16,UInt32}(zeros(UInt32, 16))
164+
ChaCha._chacha_set_initial_state!(state_set, key, nonce, counter)
165+
ChaCha._chacha_add_initial_state!(state_add, key, nonce, counter)
166+
test_vector = MVector{16,UInt32}([
167+
0x61707865, 0x3320646e, 0x79622d32, 0x6b206574,
168+
0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
169+
0x13121110, 0x17161514, 0x1b1a1918, 0x1f1e1d1c,
170+
0x00000001, 0x09000000, 0x4a000000, 0x00000000,
171+
])
172+
173+
@test state_set == test_vector
174+
@test state_add == test_vector
175+
end
176+
177+
chacha_blocks_test_suite(identity)
173178
end
174179

175180
@testset "CUDA ChaCha tests" begin
@@ -193,21 +198,36 @@ end
193198
@test state_gpu == CuArray(collect(repeat(state, 1024)))
194199
end
195200

196-
@testset "Test chacha_blocks!" begin
197-
for _ = 1:64
198-
state = zeros(UInt32, 1024)
199-
state_gpu = CUDA.zeros(UInt32, 1024)
200-
key = SVector{8,UInt32}(rand(UInt32, 8))
201-
key_gpu = CuArray(key)
202-
nonce = rand(UInt64)
203-
counter = UInt64(0)
201+
chacha_blocks_test_suite(x -> CuArray(x))
204202

205-
ctr = chacha_blocks!(state, key, nonce, counter, 1024 ÷ 16)
206-
CUDA.@sync ctr_gpu = chacha_blocks!(state_gpu, key_gpu, nonce, counter, 1024 ÷ 16)
203+
@testset "Compare chacha_blocks! output with CPU output" begin
204+
# Test with key, nonce, and counter equal to zero
205+
state = zeros(UInt32, 1024)
206+
state_gpu = CUDA.zeros(UInt32, 1024)
207+
key = SVector{8,UInt32}(zeros(UInt32, 8))
208+
key_gpu = CuArray(key)
209+
nonce = UInt64(0)
210+
counter = UInt64(0)
207211

208-
@test ctr == ctr_gpu
209-
@test state_gpu == CuArray(state)
210-
end
212+
ctr = chacha_blocks!(state, key, nonce, counter, 1024 ÷ 16)
213+
CUDA.@sync ctr_gpu = chacha_blocks!(state_gpu, key_gpu, nonce, counter, 1024 ÷ 16)
214+
215+
@test ctr == ctr_gpu
216+
@test state_gpu == CuArray(state)
217+
218+
# Test with randomized nonce and key
219+
state = zeros(UInt32, 2^16)
220+
state_gpu = CUDA.zeros(UInt32, length(state))
221+
key = SVector{8,UInt32}(rand(UInt32, 8))
222+
key_gpu = CuArray(key)
223+
nonce = rand(UInt64)
224+
counter = UInt64(0)
225+
226+
ctr = chacha_blocks!(state, key, nonce, counter, length(state) ÷ 16)
227+
CUDA.@sync ctr_gpu = chacha_blocks!(state_gpu, key_gpu, nonce, counter, length(state) ÷ 16)
228+
229+
@test ctr == ctr_gpu
230+
@test state_gpu == CuArray(state)
211231
end
212232
end
213233
end

0 commit comments

Comments
 (0)