|
| 1 | +macro_rules! modulus { |
| 2 | + ($($name:ident),*) => { |
| 3 | + $( |
| 4 | + #[derive(Copy, Clone, Eq, PartialEq)] |
| 5 | + enum $name {} |
| 6 | + |
| 7 | + impl Modulus for $name { |
| 8 | + const VALUE: u32 = $name as _; |
| 9 | + const HINT_VALUE_IS_PRIME: bool = true; |
| 10 | + |
| 11 | + fn butterfly_cache() -> &'static ::std::thread::LocalKey<::std::cell::RefCell<::std::option::Option<crate::modint::ButterflyCache<Self>>>> { |
| 12 | + thread_local! { |
| 13 | + static BUTTERFLY_CACHE: ::std::cell::RefCell<::std::option::Option<crate::modint::ButterflyCache<$name>>> = ::std::default::Default::default(); |
| 14 | + } |
| 15 | + &BUTTERFLY_CACHE |
| 16 | + } |
| 17 | + } |
| 18 | + )* |
| 19 | + }; |
| 20 | +} |
| 21 | + |
1 | 22 | use crate::{
|
2 | 23 | internal_bit, internal_math,
|
3 | 24 | modint::{ButterflyCache, Modulus, RemEuclidU32, StaticModInt},
|
4 | 25 | };
|
5 | 26 | use std::{
|
6 |
| - cell::RefCell, |
7 | 27 | cmp,
|
8 | 28 | convert::{TryFrom, TryInto as _},
|
9 | 29 | fmt,
|
10 |
| - thread::LocalKey, |
11 | 30 | };
|
12 | 31 |
|
13 | 32 | #[allow(clippy::many_single_char_names)]
|
@@ -77,28 +96,7 @@ pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
|
77 | 96 | const M1M2: u64 = M1 * M2;
|
78 | 97 | const M1M2M3: u64 = M1M2.wrapping_mul(M3);
|
79 | 98 |
|
80 |
| - macro_rules! moduli { |
81 |
| - ($($name:ident),*) => { |
82 |
| - $( |
83 |
| - #[derive(Copy, Clone, Eq, PartialEq)] |
84 |
| - enum $name {} |
85 |
| - |
86 |
| - impl Modulus for $name { |
87 |
| - const VALUE: u32 = $name as _; |
88 |
| - const HINT_VALUE_IS_PRIME: bool = true; |
89 |
| - |
90 |
| - fn butterfly_cache() -> &'static LocalKey<RefCell<Option<ButterflyCache<Self>>>> { |
91 |
| - thread_local! { |
92 |
| - static BUTTERFLY_CACHE: RefCell<Option<ButterflyCache<$name>>> = RefCell::default(); |
93 |
| - } |
94 |
| - &BUTTERFLY_CACHE |
95 |
| - } |
96 |
| - } |
97 |
| - )* |
98 |
| - }; |
99 |
| - } |
100 |
| - |
101 |
| - moduli!(M1, M2, M3); |
| 99 | + modulus!(M1, M2, M3); |
102 | 100 |
|
103 | 101 | if a.is_empty() || b.is_empty() {
|
104 | 102 | return vec![];
|
@@ -230,3 +228,85 @@ fn prepare<M: Modulus>() -> ButterflyCache<M> {
|
230 | 228 | .collect();
|
231 | 229 | ButterflyCache { sum_e, sum_ie }
|
232 | 230 | }
|
| 231 | + |
| 232 | +#[cfg(test)] |
| 233 | +mod tests { |
| 234 | + use crate::modint::{Mod998244353, Modulus, StaticModInt}; |
| 235 | + use rand::{rngs::ThreadRng, Rng as _}; |
| 236 | + |
| 237 | + // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L73-L85 |
| 238 | + #[test] |
| 239 | + fn mid() { |
| 240 | + const N: usize = 1234; |
| 241 | + const M: usize = 2345; |
| 242 | + |
| 243 | + let mut rng = rand::thread_rng(); |
| 244 | + let mut gen_values = |n| gen_values::<Mod998244353>(&mut rng, n); |
| 245 | + let (a, b) = (gen_values(N), gen_values(M)); |
| 246 | + assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b)); |
| 247 | + } |
| 248 | + |
| 249 | + // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L87-L118 |
| 250 | + #[test] |
| 251 | + fn simple_s_mod() { |
| 252 | + const M1: u32 = 998_244_353; |
| 253 | + const M2: u32 = 924_844_033; |
| 254 | + |
| 255 | + modulus!(M1, M2); |
| 256 | + |
| 257 | + fn test<M: Modulus>(rng: &mut ThreadRng) { |
| 258 | + let mut gen_values = |n| gen_values::<Mod998244353>(rng, n); |
| 259 | + for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) { |
| 260 | + let (a, b) = (gen_values(n), gen_values(m)); |
| 261 | + assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b)); |
| 262 | + } |
| 263 | + } |
| 264 | + |
| 265 | + let mut rng = rand::thread_rng(); |
| 266 | + test::<M1>(&mut rng); |
| 267 | + test::<M2>(&mut rng); |
| 268 | + } |
| 269 | + |
| 270 | + // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L358-L371 |
| 271 | + #[test] |
| 272 | + fn conv641() { |
| 273 | + const M: u32 = 641; |
| 274 | + modulus!(M); |
| 275 | + |
| 276 | + let mut rng = rand::thread_rng(); |
| 277 | + let mut gen_values = |n| gen_values::<M>(&mut rng, n); |
| 278 | + let (a, b) = (gen_values(64), gen_values(65)); |
| 279 | + assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b)); |
| 280 | + } |
| 281 | + |
| 282 | + // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L373-L386 |
| 283 | + #[test] |
| 284 | + fn conv18433() { |
| 285 | + const M: u32 = 18433; |
| 286 | + modulus!(M); |
| 287 | + |
| 288 | + let mut rng = rand::thread_rng(); |
| 289 | + let mut gen_values = |n| gen_values::<M>(&mut rng, n); |
| 290 | + let (a, b) = (gen_values(1024), gen_values(1025)); |
| 291 | + assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b)); |
| 292 | + } |
| 293 | + |
| 294 | + #[allow(clippy::many_single_char_names)] |
| 295 | + fn conv_naive<M: Modulus>( |
| 296 | + a: &[StaticModInt<M>], |
| 297 | + b: &[StaticModInt<M>], |
| 298 | + ) -> Vec<StaticModInt<M>> { |
| 299 | + let (n, m) = (a.len(), b.len()); |
| 300 | + let mut c = vec![StaticModInt::raw(0); n + m - 1]; |
| 301 | + for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) { |
| 302 | + c[i + j] += a[i] * b[j]; |
| 303 | + } |
| 304 | + c |
| 305 | + } |
| 306 | + |
| 307 | + fn gen_values<M: Modulus>(rng: &mut ThreadRng, n: usize) -> Vec<StaticModInt<M>> { |
| 308 | + (0..n) |
| 309 | + .map(|_| StaticModInt::raw(rng.gen_range(0, M::VALUE))) |
| 310 | + .collect() |
| 311 | + } |
| 312 | +} |
0 commit comments