@@ -12,71 +12,14 @@ using CUDA
12
12
using StaticArrays
13
13
using Test
14
14
15
- @testset " ChaCha tests" begin
16
- @testset " Quarter-round function tests" begin
17
- # Ref: IETF RFC 8439, Sec. 2.1.1
18
- # https://datatracker.ietf.org/doc/html/rfc8439#section-2.1.1
19
- state = MVector {4,UInt32} ([0x11111111 , 0x01020304 , 0x9b8d6f43 , 0x01234567 ])
20
- ChaCha. _QR! (state, 1 , 2 , 3 , 4 )
21
-
22
- expected_state = SVector {4,UInt32} ([0xea2a92f4 , 0xcb1cf8ce , 0x4581472e , 0x5881c4bb ])
23
-
24
- @test state == expected_state
25
-
26
- # Ref: IETF RFC 8439, Sec. 2.2.1
27
- # https://datatracker.ietf.org/doc/html/rfc8439#section-2.2.1
28
-
29
- state = MVector {16,UInt32} ([
30
- 0x879531e0 , 0xc5ecf37d , 0x516461b1 , 0xc9a62f8a ,
31
- 0x44c20ef3 , 0x3390af7f , 0xd9fc690b , 0x2a5f714c ,
32
- 0x53372767 , 0xb00a5631 , 0x974c541a , 0x359e9963 ,
33
- 0x5c971061 , 0x3d631689 , 0x2098d9d6 , 0x91dbd320 ,
34
- ])
35
- initial_state = deepcopy (state)
36
-
37
- ChaCha. _QR! (state, 3 , 8 , 9 , 14 )
38
-
39
- mask = trues (length (state))
40
- mask[3 ] = mask[8 ] = mask[9 ] = mask[14 ] = false
41
-
42
- @test state[mask] == initial_state[mask]
43
- @test state[@. (~ mask)] == [
44
- 0xbdb886dc , 0xcfacafd2 , 0xe46bea80 , 0xccc07c79
45
- ]
46
- end
47
-
48
- @testset " Test initial state" begin
49
- # Ref: IETF RFC 8439, Sec. 2.2.1
50
- # https://datatracker.ietf.org/doc/html/rfc8439#section-2.2.1
51
- key = SVector {8,UInt32} ([
52
- 0x03020100 , 0x07060504 , 0x0b0a0908 , 0x0f0e0d0c ,
53
- 0x13121110 , 0x17161514 , 0x1b1a1918 , 0x1f1e1d1c ,
54
- ])
55
- nonce = 0x000000004a000000
56
- counter = 0x0900000000000001
57
-
58
- state_set = MVector {16,UInt32} (undef)
59
- state_add = MVector {16,UInt32} (zeros (UInt32, 16 ))
60
- ChaCha. _chacha_set_initial_state! (state_set, key, nonce, counter)
61
- ChaCha. _chacha_add_initial_state! (state_add, key, nonce, counter)
62
- test_vector = MVector {16,UInt32} ([
63
- 0x61707865 , 0x3320646e , 0x79622d32 , 0x6b206574 ,
64
- 0x03020100 , 0x07060504 , 0x0b0a0908 , 0x0f0e0d0c ,
65
- 0x13121110 , 0x17161514 , 0x1b1a1918 , 0x1f1e1d1c ,
66
- 0x00000001 , 0x09000000 , 0x4a000000 , 0x00000000 ,
67
- ])
68
-
69
- @test state_set == test_vector
70
- @test state_add == test_vector
71
- end
72
-
15
+ function chacha_blocks_test_suite (T)
73
16
@testset " Test chacha_blocks!" begin
74
17
# Ref: IETF RFC 8439, Sec. 2.3.2
75
18
# https://datatracker.ietf.org/doc/html/rfc8439#section-2.3.2
76
19
key = SVector {8,UInt32} ([
77
20
0x03020100 , 0x07060504 , 0x0b0a0908 , 0x0f0e0d0c ,
78
21
0x13121110 , 0x17161514 , 0x1b1a1918 , 0x1f1e1d1c ,
79
- ])
22
+ ]) |> T
80
23
nonce = 0x000000004a000000
81
24
counter = 0x0900000000000001
82
25
@@ -85,8 +28,8 @@ using Test
85
28
0xc7f4d1c7 , 0x0368c033 , 0x9aaa2204 , 0x4e6cd4c3 ,
86
29
0x466482d2 , 0x09aa9f07 , 0x05d7c214 , 0xa2028bd9 ,
87
30
0xd19c12b5 , 0xb94e16de , 0xe883d0cb , 0x4e3c50a2 ,
88
- ])
89
- state = MVector {16,UInt32} (undef)
31
+ ]) |> T
32
+ state = MVector {16,UInt32} (undef) |> T
90
33
@test ChaCha. chacha_blocks! (state, key, nonce, counter) == counter + 1
91
34
@test state == test_vector
92
35
@@ -95,30 +38,30 @@ using Test
95
38
96
39
# Test Vector #1:
97
40
# ==============
98
- key = SVector {8,UInt32} (zeros (UInt32, 8 ))
41
+ key = SVector {8,UInt32} (zeros (UInt32, 8 )) |> T
99
42
nonce = UInt64 (0 )
100
43
counter = UInt64 (0 )
101
44
test_vector = SVector {16,UInt32} ([
102
45
0xade0b876 , 0x903df1a0 , 0xe56a5d40 , 0x28bd8653 ,
103
46
0xb819d2bd , 0x1aed8da0 , 0xccef36a8 , 0xc70d778b ,
104
47
0x7c5941da , 0x8d485751 , 0x3fe02477 , 0x374ad8b8 ,
105
48
0xf4b8436a , 0x1ca11815 , 0x69b687c3 , 0x8665eeb2
106
- ])
107
- state = MVector {16,UInt32} (undef)
49
+ ]) |> T
50
+ state = MVector {16,UInt32} (undef) |> T
108
51
@test ChaCha. chacha_blocks! (state, key, nonce, counter) == counter + 1
109
52
@test state == test_vector
110
53
111
54
# Test Vector #2:
112
55
# ==============
113
- key = SVector {8,UInt32} (zeros (UInt32, 8 ))
56
+ key = SVector {8,UInt32} (zeros (UInt32, 8 )) |> T
114
57
nonce = UInt64 (0 )
115
58
counter = UInt64 (1 )
116
59
test_vector = SVector {16,UInt32} ([
117
60
0xbee7079f , 0x7a385155 , 0x7c97ba98 , 0x0d082d73 ,
118
61
0xa0290fcb , 0x6965e348 , 0x3e53c612 , 0xed7aee32 ,
119
62
0x7621b729 , 0x434ee69c , 0xb03371d5 , 0xd539d874 ,
120
63
0x281fed31 , 0x45fb0a51 , 0x1f0ae1ac , 0x6f4d794b
121
- ])
64
+ ]) |> T
122
65
@test ChaCha. chacha_blocks! (state, key, nonce, counter) == counter + 1
123
66
@test state == test_vector
124
67
@@ -127,13 +70,13 @@ using Test
127
70
key = SVector {8,UInt32} ([
128
71
0x00000000 , 0x00000000 , 0x00000000 , 0x00000000 ,
129
72
0x00000000 , 0x00000000 , 0x00000000 , 0x01000000 ,
130
- ])
73
+ ]) |> T
131
74
test_vector = SVector {16,UInt32} ([
132
75
0x2452eb3a , 0x9249f8ec , 0x8d829d9b , 0xddd4ceb1 ,
133
76
0xe8252083 , 0x60818b01 , 0xf38422b8 , 0x5aaa49c9 ,
134
77
0xbb00ca8e , 0xda3ba7b4 , 0xc4b592d1 , 0xfdf2732f ,
135
78
0x4436274e , 0x2561b3c8 , 0xebdd4aa6 , 0xa0136c00
136
- ])
79
+ ]) |> T
137
80
nonce = UInt64 (0 )
138
81
counter = UInt64 (1 )
139
82
@test ChaCha. chacha_blocks! (state, key, nonce, counter) == counter + 1
@@ -144,15 +87,15 @@ using Test
144
87
key = SVector {8,UInt32} ([
145
88
0x0000ff00 , 0x00000000 , 0x00000000 , 0x00000000 ,
146
89
0x00000000 , 0x00000000 , 0x00000000 , 0x00000000 ,
147
- ])
90
+ ]) |> T
148
91
nonce = UInt64 (0 )
149
92
counter = UInt64 (2 )
150
93
test_vector = SVector {16,UInt32} ([
151
94
0xfb4dd572 , 0x4bc42ef1 , 0xdf922636 , 0x327f1394 ,
152
95
0xa78dea8f , 0x5e269039 , 0xa1bebbc1 , 0xcaf09aae ,
153
96
0xa25ab213 , 0x48a6b46c , 0x1b9d9bcb , 0x092c5be6 ,
154
97
0x546ca624 , 0x1bec45d5 , 0x87f47473 , 0x96f0992e
155
- ])
98
+ ]) |> T
156
99
@test ChaCha. chacha_blocks! (state, key, nonce, counter) == counter + 1
157
100
@test state == test_vector
158
101
@@ -166,10 +109,72 @@ using Test
166
109
0x88228b1a , 0x96a4dfb3 , 0x5b76ab72 , 0xc727ee54 ,
167
110
0x0e0e978a , 0xf3145c95 , 0x1b748ea8 , 0xf786c297 ,
168
111
0x99c28f5f , 0x628314e8 , 0x398a19fa , 0x6ded1b53
169
- ])
112
+ ]) |> T
170
113
@test ChaCha. chacha_blocks! (state, key, nonce, counter) == counter + 1
171
114
@test state == test_vector
172
115
end
116
+
117
+ end
118
+
119
+ @testset " ChaCha tests" begin
120
+ @testset " Quarter-round function tests" begin
121
+ # Ref: IETF RFC 8439, Sec. 2.1.1
122
+ # https://datatracker.ietf.org/doc/html/rfc8439#section-2.1.1
123
+ state = MVector {4,UInt32} ([0x11111111 , 0x01020304 , 0x9b8d6f43 , 0x01234567 ])
124
+ ChaCha. _QR! (state, 1 , 2 , 3 , 4 )
125
+
126
+ expected_state = SVector {4,UInt32} ([0xea2a92f4 , 0xcb1cf8ce , 0x4581472e , 0x5881c4bb ])
127
+
128
+ @test state == expected_state
129
+
130
+ # Ref: IETF RFC 8439, Sec. 2.2.1
131
+ # https://datatracker.ietf.org/doc/html/rfc8439#section-2.2.1
132
+
133
+ state = MVector {16,UInt32} ([
134
+ 0x879531e0 , 0xc5ecf37d , 0x516461b1 , 0xc9a62f8a ,
135
+ 0x44c20ef3 , 0x3390af7f , 0xd9fc690b , 0x2a5f714c ,
136
+ 0x53372767 , 0xb00a5631 , 0x974c541a , 0x359e9963 ,
137
+ 0x5c971061 , 0x3d631689 , 0x2098d9d6 , 0x91dbd320 ,
138
+ ])
139
+ initial_state = deepcopy (state)
140
+
141
+ ChaCha. _QR! (state, 3 , 8 , 9 , 14 )
142
+
143
+ mask = trues (length (state))
144
+ mask[3 ] = mask[8 ] = mask[9 ] = mask[14 ] = false
145
+
146
+ @test state[mask] == initial_state[mask]
147
+ @test state[@. (~ mask)] == [
148
+ 0xbdb886dc , 0xcfacafd2 , 0xe46bea80 , 0xccc07c79
149
+ ]
150
+ end
151
+
152
+ @testset " Test initial state" begin
153
+ # Ref: IETF RFC 8439, Sec. 2.2.1
154
+ # https://datatracker.ietf.org/doc/html/rfc8439#section-2.2.1
155
+ key = SVector {8,UInt32} ([
156
+ 0x03020100 , 0x07060504 , 0x0b0a0908 , 0x0f0e0d0c ,
157
+ 0x13121110 , 0x17161514 , 0x1b1a1918 , 0x1f1e1d1c ,
158
+ ])
159
+ nonce = 0x000000004a000000
160
+ counter = 0x0900000000000001
161
+
162
+ state_set = MVector {16,UInt32} (undef)
163
+ state_add = MVector {16,UInt32} (zeros (UInt32, 16 ))
164
+ ChaCha. _chacha_set_initial_state! (state_set, key, nonce, counter)
165
+ ChaCha. _chacha_add_initial_state! (state_add, key, nonce, counter)
166
+ test_vector = MVector {16,UInt32} ([
167
+ 0x61707865 , 0x3320646e , 0x79622d32 , 0x6b206574 ,
168
+ 0x03020100 , 0x07060504 , 0x0b0a0908 , 0x0f0e0d0c ,
169
+ 0x13121110 , 0x17161514 , 0x1b1a1918 , 0x1f1e1d1c ,
170
+ 0x00000001 , 0x09000000 , 0x4a000000 , 0x00000000 ,
171
+ ])
172
+
173
+ @test state_set == test_vector
174
+ @test state_add == test_vector
175
+ end
176
+
177
+ chacha_blocks_test_suite (identity)
173
178
end
174
179
175
180
@testset " CUDA ChaCha tests" begin
@@ -193,21 +198,36 @@ end
193
198
@test state_gpu == CuArray (collect (repeat (state, 1024 )))
194
199
end
195
200
196
- @testset " Test chacha_blocks!" begin
197
- for _ = 1 : 64
198
- state = zeros (UInt32, 1024 )
199
- state_gpu = CUDA. zeros (UInt32, 1024 )
200
- key = SVector {8,UInt32} (rand (UInt32, 8 ))
201
- key_gpu = CuArray (key)
202
- nonce = rand (UInt64)
203
- counter = UInt64 (0 )
201
+ chacha_blocks_test_suite (x -> CuArray (x))
204
202
205
- ctr = chacha_blocks! (state, key, nonce, counter, 1024 ÷ 16 )
206
- CUDA. @sync ctr_gpu = chacha_blocks! (state_gpu, key_gpu, nonce, counter, 1024 ÷ 16 )
203
+ @testset " Compare chacha_blocks! output with CPU output" begin
204
+ # Test with key, nonce, and counter equal to zero
205
+ state = zeros (UInt32, 1024 )
206
+ state_gpu = CUDA. zeros (UInt32, 1024 )
207
+ key = SVector {8,UInt32} (zeros (UInt32, 8 ))
208
+ key_gpu = CuArray (key)
209
+ nonce = UInt64 (0 )
210
+ counter = UInt64 (0 )
207
211
208
- @test ctr == ctr_gpu
209
- @test state_gpu == CuArray (state)
210
- end
212
+ ctr = chacha_blocks! (state, key, nonce, counter, 1024 ÷ 16 )
213
+ CUDA. @sync ctr_gpu = chacha_blocks! (state_gpu, key_gpu, nonce, counter, 1024 ÷ 16 )
214
+
215
+ @test ctr == ctr_gpu
216
+ @test state_gpu == CuArray (state)
217
+
218
+ # Test with randomized nonce and key
219
+ state = zeros (UInt32, 2 ^ 16 )
220
+ state_gpu = CUDA. zeros (UInt32, length (state))
221
+ key = SVector {8,UInt32} (rand (UInt32, 8 ))
222
+ key_gpu = CuArray (key)
223
+ nonce = rand (UInt64)
224
+ counter = UInt64 (0 )
225
+
226
+ ctr = chacha_blocks! (state, key, nonce, counter, length (state) ÷ 16 )
227
+ CUDA. @sync ctr_gpu = chacha_blocks! (state_gpu, key_gpu, nonce, counter, length (state) ÷ 16 )
228
+
229
+ @test ctr == ctr_gpu
230
+ @test state_gpu == CuArray (state)
211
231
end
212
232
end
213
233
end
0 commit comments