|
1 | 1 | use crate::{
|
2 | 2 | internal_bit, internal_math,
|
3 |
| - modint::{ButterflyCache, Modulus, StaticModInt}, |
| 3 | + modint::{ButterflyCache, Modulus, RemEuclidU32, StaticModInt}, |
| 4 | +}; |
| 5 | +use std::{ |
| 6 | + cell::RefCell, |
| 7 | + cmp, |
| 8 | + convert::{TryFrom, TryInto as _}, |
| 9 | + fmt, |
| 10 | + thread::LocalKey, |
4 | 11 | };
|
5 |
| -use std::{cell::RefCell, cmp, thread::LocalKey}; |
6 | 12 |
|
7 | 13 | #[allow(clippy::many_single_char_names)]
|
8 | 14 | pub fn convolution<M: Modulus>(
|
@@ -43,6 +49,26 @@ pub fn convolution<M: Modulus>(
|
43 | 49 | a
|
44 | 50 | }
|
45 | 51 |
|
| 52 | +pub fn convolution_raw< |
| 53 | + T: RemEuclidU32 + TryFrom<u32, Error = E> + Clone, |
| 54 | + E: fmt::Debug, |
| 55 | + M: Modulus, |
| 56 | +>( |
| 57 | + a: &[T], |
| 58 | + b: &[T], |
| 59 | +) -> Vec<T> { |
| 60 | + let a = a.iter().cloned().map(Into::into).collect::<Vec<_>>(); |
| 61 | + let b = b.iter().cloned().map(Into::into).collect::<Vec<_>>(); |
| 62 | + convolution::<M>(&a, &b) |
| 63 | + .into_iter() |
| 64 | + .map(|z| { |
| 65 | + z.val() |
| 66 | + .try_into() |
| 67 | + .expect("the numeric type is smaller than the modulus") |
| 68 | + }) |
| 69 | + .collect() |
| 70 | +} |
| 71 | + |
46 | 72 | #[allow(clippy::many_single_char_names)]
|
47 | 73 | pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
|
48 | 74 | const M1: u64 = 754_974_721; // 2^24
|
@@ -84,17 +110,9 @@ pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
|
84 | 110 | let i2 = internal_math::inv_gcd(M1M3 as _, M2 as _).1;
|
85 | 111 | let i3 = internal_math::inv_gcd(M1M2 as _, M3 as _).1;
|
86 | 112 |
|
87 |
| - let (c1, c2, c3) = { |
88 |
| - fn c<M: Modulus>(a: &[i64], b: &[i64]) -> Vec<i64> { |
89 |
| - let a = a.iter().copied().map(Into::into).collect::<Vec<_>>(); |
90 |
| - let b = b.iter().copied().map(Into::into).collect::<Vec<_>>(); |
91 |
| - convolution::<M>(&a, &b) |
92 |
| - .into_iter() |
93 |
| - .map(|z| z.val().into()) |
94 |
| - .collect() |
95 |
| - } |
96 |
| - (c::<M1>(a, b), c::<M2>(a, b), c::<M3>(a, b)) |
97 |
| - }; |
| 113 | + let c1 = convolution_raw::<i64, _, M1>(a, b); |
| 114 | + let c2 = convolution_raw::<i64, _, M2>(a, b); |
| 115 | + let c3 = convolution_raw::<i64, _, M3>(a, b); |
98 | 116 |
|
99 | 117 | c1.into_iter()
|
100 | 118 | .zip(c2)
|
|
0 commit comments