Skip to content

Commit b8043a0

Browse files
committed
Add some tests
1 parent 838b45b commit b8043a0

File tree

2 files changed

+107
-24
lines changed

2 files changed

+107
-24
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ publish = false
88
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
99

1010
[dependencies]
11+
12+
[dev-dependencies]
13+
rand = "0.7.3"

src/convolution.rs

Lines changed: 104 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,32 @@
1+
macro_rules! modulus {
2+
($($name:ident),*) => {
3+
$(
4+
#[derive(Copy, Clone, Eq, PartialEq)]
5+
enum $name {}
6+
7+
impl Modulus for $name {
8+
const VALUE: u32 = $name as _;
9+
const HINT_VALUE_IS_PRIME: bool = true;
10+
11+
fn butterfly_cache() -> &'static ::std::thread::LocalKey<::std::cell::RefCell<::std::option::Option<crate::modint::ButterflyCache<Self>>>> {
12+
thread_local! {
13+
static BUTTERFLY_CACHE: ::std::cell::RefCell<::std::option::Option<crate::modint::ButterflyCache<$name>>> = ::std::default::Default::default();
14+
}
15+
&BUTTERFLY_CACHE
16+
}
17+
}
18+
)*
19+
};
20+
}
21+
122
use crate::{
223
internal_bit, internal_math,
324
modint::{ButterflyCache, Modulus, RemEuclidU32, StaticModInt},
425
};
526
use std::{
6-
cell::RefCell,
727
cmp,
828
convert::{TryFrom, TryInto as _},
929
fmt,
10-
thread::LocalKey,
1130
};
1231

1332
#[allow(clippy::many_single_char_names)]
@@ -77,28 +96,7 @@ pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
7796
const M1M2: u64 = M1 * M2;
7897
const M1M2M3: u64 = M1M2.wrapping_mul(M3);
7998

80-
macro_rules! moduli {
81-
($($name:ident),*) => {
82-
$(
83-
#[derive(Copy, Clone, Eq, PartialEq)]
84-
enum $name {}
85-
86-
impl Modulus for $name {
87-
const VALUE: u32 = $name as _;
88-
const HINT_VALUE_IS_PRIME: bool = true;
89-
90-
fn butterfly_cache() -> &'static LocalKey<RefCell<Option<ButterflyCache<Self>>>> {
91-
thread_local! {
92-
static BUTTERFLY_CACHE: RefCell<Option<ButterflyCache<$name>>> = RefCell::default();
93-
}
94-
&BUTTERFLY_CACHE
95-
}
96-
}
97-
)*
98-
};
99-
}
100-
101-
moduli!(M1, M2, M3);
99+
modulus!(M1, M2, M3);
102100

103101
if a.is_empty() || b.is_empty() {
104102
return vec![];
@@ -230,3 +228,85 @@ fn prepare<M: Modulus>() -> ButterflyCache<M> {
230228
.collect();
231229
ButterflyCache { sum_e, sum_ie }
232230
}
231+
232+
#[cfg(test)]
233+
mod tests {
234+
use crate::modint::{Mod998244353, Modulus, StaticModInt};
235+
use rand::{rngs::ThreadRng, Rng as _};
236+
237+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L73-L85
238+
#[test]
239+
fn mid() {
240+
const N: usize = 1234;
241+
const M: usize = 2345;
242+
243+
let mut rng = rand::thread_rng();
244+
let mut gen_values = |n| gen_values::<Mod998244353>(&mut rng, n);
245+
let (a, b) = (gen_values(N), gen_values(M));
246+
assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
247+
}
248+
249+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L87-L118
250+
#[test]
251+
fn simple_s_mod() {
252+
const M1: u32 = 998_244_353;
253+
const M2: u32 = 924_844_033;
254+
255+
modulus!(M1, M2);
256+
257+
fn test<M: Modulus>(rng: &mut ThreadRng) {
258+
let mut gen_values = |n| gen_values::<Mod998244353>(rng, n);
259+
for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
260+
let (a, b) = (gen_values(n), gen_values(m));
261+
assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
262+
}
263+
}
264+
265+
let mut rng = rand::thread_rng();
266+
test::<M1>(&mut rng);
267+
test::<M2>(&mut rng);
268+
}
269+
270+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L358-L371
271+
#[test]
272+
fn conv641() {
273+
const M: u32 = 641;
274+
modulus!(M);
275+
276+
let mut rng = rand::thread_rng();
277+
let mut gen_values = |n| gen_values::<M>(&mut rng, n);
278+
let (a, b) = (gen_values(64), gen_values(65));
279+
assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
280+
}
281+
282+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L373-L386
283+
#[test]
284+
fn conv18433() {
285+
const M: u32 = 18433;
286+
modulus!(M);
287+
288+
let mut rng = rand::thread_rng();
289+
let mut gen_values = |n| gen_values::<M>(&mut rng, n);
290+
let (a, b) = (gen_values(1024), gen_values(1025));
291+
assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
292+
}
293+
294+
#[allow(clippy::many_single_char_names)]
295+
fn conv_naive<M: Modulus>(
296+
a: &[StaticModInt<M>],
297+
b: &[StaticModInt<M>],
298+
) -> Vec<StaticModInt<M>> {
299+
let (n, m) = (a.len(), b.len());
300+
let mut c = vec![StaticModInt::raw(0); n + m - 1];
301+
for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
302+
c[i + j] += a[i] * b[j];
303+
}
304+
c
305+
}
306+
307+
fn gen_values<M: Modulus>(rng: &mut ThreadRng, n: usize) -> Vec<StaticModInt<M>> {
308+
(0..n)
309+
.map(|_| StaticModInt::raw(rng.gen_range(0, M::VALUE)))
310+
.collect()
311+
}
312+
}

0 commit comments

Comments
 (0)