9
9
namespace cp_algo ::math::fft {
10
10
template <modint_type base>
11
11
struct dft {
12
- int split;
13
12
cvector A, B;
14
13
static base factor, ifactor;
15
14
static bool init;
15
+ static int split;
16
16
17
17
dft (auto const & a, size_t n): A(n), B(n) {
18
18
if (!init) {
19
19
factor = 1 + random::rng () % (base::mod () - 1 );
20
+ split = int (std::sqrt (base::mod ())) + 1 ;
20
21
ifactor = base (1 ) / factor;
21
22
init = true ;
22
23
}
23
- split = int ( std::sqrt ( base::mod ())) + 1 ;
24
- base cur = 1 ;
24
+ base cur = factor ;
25
+ base step = bpow (factor, n) ;
25
26
cvector::exec_on_roots (2 * n, std::min (n, size (a)), [&](size_t i, auto rt) {
26
- auto splt = [&](size_t i) {
27
- auto ai = ftype (i < size (a) ? (a[i] * cur ).rem () : 0 );
27
+ auto splt = [&](size_t i, auto mul ) {
28
+ auto ai = ftype (i < size (a) ? (a[i] * mul ).rem () : 0 );
28
29
auto rem = std::remainder (ai, split);
29
30
auto quo = (ai - rem) / split;
30
31
return std::pair{rem, quo};
31
32
};
32
- auto [rai, qai] = splt (i);
33
- auto [rani, qani] = splt (n + i);
33
+ auto [rai, qai] = splt (i, cur );
34
+ auto [rani, qani] = splt (n + i, cur * step );
34
35
A.set (i, point (rai, rani) * rt);
35
36
B.set (i, point (qai, qani) * rt);
36
37
cur *= factor;
@@ -42,7 +43,7 @@ namespace cp_algo::math::fft {
42
43
}
43
44
}
44
45
45
- void mul (auto &&C, auto const & D, auto &res, size_t k, [[maybe_unused]] base ifactor ) {
46
+ void mul (auto &&C, auto const & D, auto &res, size_t k) {
46
47
assert (A.size () == C.size ());
47
48
size_t n = A.size ();
48
49
if (!n) {
@@ -83,7 +84,7 @@ namespace cp_algo::math::fft {
83
84
B.ifft ();
84
85
C.ifft ();
85
86
auto splitsplit = (base (split) * split).rem ();
86
- base cur = 1 ;
87
+ base cur = ifactor * ifactor ;
87
88
base step = bpow (ifactor, n);
88
89
cvector::exec_on_roots (2 * n, std::min (n, k), [&](size_t i, point rt) {
89
90
rt = conj (rt);
@@ -95,23 +96,22 @@ namespace cp_algo::math::fft {
95
96
int64_t A2 = llround (real (Bi));
96
97
res[i] = A0 + A1 * split + A2 * splitsplit;
97
98
res[i] *= cur;
98
- if (n + i >= k) {
99
- return ;
99
+ if (n + i < k) {
100
+ int64_t B0 = llround (imag (Ai));
101
+ int64_t B1 = llround (imag (Ci));
102
+ int64_t B2 = llround (imag (Bi));
103
+ res[n + i] = B0 + B1 * split + B2 * splitsplit;
104
+ res[n + i] *= cur * step;
100
105
}
101
- int64_t B0 = llround (imag (Ai));
102
- int64_t B1 = llround (imag (Ci));
103
- int64_t B2 = llround (imag (Bi));
104
- res[n + i] = B0 + B1 * split + B2 * splitsplit;
105
- res[n + i] *= cur * step;
106
106
cur *= ifactor;
107
107
});
108
108
checkpoint (" recover mod" );
109
109
}
110
110
void mul_inplace (auto &&B, auto & res, size_t k) {
111
- mul (B.A , B.B , res, k, ifactor * B. ifactor );
111
+ mul (B.A , B.B , res, k);
112
112
}
113
113
void mul (auto const & B, auto & res, size_t k) {
114
- mul (cvector (B.A ), B.B , res, k, ifactor * B. ifactor );
114
+ mul (cvector (B.A ), B.B , res, k);
115
115
}
116
116
std::vector<base> operator *= (dft &B) {
117
117
std::vector<base> res (2 * A.size ());
@@ -132,6 +132,7 @@ namespace cp_algo::math::fft {
132
132
template <modint_type base> base dft<base>::factor = 1 ;
133
133
template <modint_type base> base dft<base>::ifactor = 1 ;
134
134
template <modint_type base> bool dft<base>::init = false ;
135
+ template <modint_type base> int dft<base>::split = 1 ;
135
136
136
137
void mul_slow (auto &a, auto const & b, size_t k) {
137
138
if (empty (a) || empty (b)) {
0 commit comments