Skip to content

Commit 4524c31

Browse files
committed
Add convolution_raw
1 parent 2cee3c9 commit 4524c31

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

src/convolution.rs

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
use crate::{
22
internal_bit, internal_math,
3-
modint::{ButterflyCache, Modulus, StaticModInt},
3+
modint::{ButterflyCache, Modulus, RemEuclidU32, StaticModInt},
4+
};
5+
use std::{
6+
cell::RefCell,
7+
cmp,
8+
convert::{TryFrom, TryInto as _},
9+
fmt,
10+
thread::LocalKey,
411
};
5-
use std::{cell::RefCell, cmp, thread::LocalKey};
612

713
#[allow(clippy::many_single_char_names)]
814
pub fn convolution<M: Modulus>(
@@ -43,6 +49,26 @@ pub fn convolution<M: Modulus>(
4349
a
4450
}
4551

52+
pub fn convolution_raw<
53+
T: RemEuclidU32 + TryFrom<u32, Error = E> + Clone,
54+
E: fmt::Debug,
55+
M: Modulus,
56+
>(
57+
a: &[T],
58+
b: &[T],
59+
) -> Vec<T> {
60+
let a = a.iter().cloned().map(Into::into).collect::<Vec<_>>();
61+
let b = b.iter().cloned().map(Into::into).collect::<Vec<_>>();
62+
convolution::<M>(&a, &b)
63+
.into_iter()
64+
.map(|z| {
65+
z.val()
66+
.try_into()
67+
.expect("the numeric type is smaller than the modulus")
68+
})
69+
.collect()
70+
}
71+
4672
#[allow(clippy::many_single_char_names)]
4773
pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
4874
const M1: u64 = 754_974_721; // 2^24
@@ -84,17 +110,9 @@ pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
84110
let i2 = internal_math::inv_gcd(M1M3 as _, M2 as _).1;
85111
let i3 = internal_math::inv_gcd(M1M2 as _, M3 as _).1;
86112

87-
let (c1, c2, c3) = {
88-
fn c<M: Modulus>(a: &[i64], b: &[i64]) -> Vec<i64> {
89-
let a = a.iter().copied().map(Into::into).collect::<Vec<_>>();
90-
let b = b.iter().copied().map(Into::into).collect::<Vec<_>>();
91-
convolution::<M>(&a, &b)
92-
.into_iter()
93-
.map(|z| z.val().into())
94-
.collect()
95-
}
96-
(c::<M1>(a, b), c::<M2>(a, b), c::<M3>(a, b))
97-
};
113+
let c1 = convolution_raw::<i64, _, M1>(a, b);
114+
let c2 = convolution_raw::<i64, _, M2>(a, b);
115+
let c3 = convolution_raw::<i64, _, M3>(a, b);
98116

99117
c1.into_iter()
100118
.zip(c2)

0 commit comments

Comments
 (0)