Skip to content

Commit 6508473

Browse files
committed
fix
1 parent 0ba3e06 commit 6508473

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

cp-algo/math/fft.hpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,29 @@
99
namespace cp_algo::math::fft {
1010
template<modint_type base>
1111
struct dft {
12-
int split;
1312
cvector A, B;
1413
static base factor, ifactor;
1514
static bool init;
15+
static int split;
1616

1717
dft(auto const& a, size_t n): A(n), B(n) {
1818
if(!init) {
1919
factor = 1 + random::rng() % (base::mod() - 1);
20+
split = int(std::sqrt(base::mod())) + 1;
2021
ifactor = base(1) / factor;
2122
init = true;
2223
}
23-
split = int(std::sqrt(base::mod())) + 1;
24-
base cur = 1;
24+
base cur = factor;
25+
base step = bpow(factor, n);
2526
cvector::exec_on_roots(2 * n, std::min(n, size(a)), [&](size_t i, auto rt) {
26-
auto splt = [&](size_t i) {
27-
auto ai = ftype(i < size(a) ? (a[i] * cur).rem() : 0);
27+
auto splt = [&](size_t i, auto mul) {
28+
auto ai = ftype(i < size(a) ? (a[i] * mul).rem() : 0);
2829
auto rem = std::remainder(ai, split);
2930
auto quo = (ai - rem) / split;
3031
return std::pair{rem, quo};
3132
};
32-
auto [rai, qai] = splt(i);
33-
auto [rani, qani] = splt(n + i);
33+
auto [rai, qai] = splt(i, cur);
34+
auto [rani, qani] = splt(n + i, cur * step);
3435
A.set(i, point(rai, rani) * rt);
3536
B.set(i, point(qai, qani) * rt);
3637
cur *= factor;
@@ -42,7 +43,7 @@ namespace cp_algo::math::fft {
4243
}
4344
}
4445

45-
void mul(auto &&C, auto const& D, auto &res, size_t k, [[maybe_unused]] base ifactor) {
46+
void mul(auto &&C, auto const& D, auto &res, size_t k) {
4647
assert(A.size() == C.size());
4748
size_t n = A.size();
4849
if(!n) {
@@ -83,7 +84,7 @@ namespace cp_algo::math::fft {
8384
B.ifft();
8485
C.ifft();
8586
auto splitsplit = (base(split) * split).rem();
86-
base cur = 1;
87+
base cur = ifactor * ifactor;
8788
base step = bpow(ifactor, n);
8889
cvector::exec_on_roots(2 * n, std::min(n, k), [&](size_t i, point rt) {
8990
rt = conj(rt);
@@ -95,23 +96,22 @@ namespace cp_algo::math::fft {
9596
int64_t A2 = llround(real(Bi));
9697
res[i] = A0 + A1 * split + A2 * splitsplit;
9798
res[i] *= cur;
98-
if(n + i >= k) {
99-
return;
99+
if(n + i < k) {
100+
int64_t B0 = llround(imag(Ai));
101+
int64_t B1 = llround(imag(Ci));
102+
int64_t B2 = llround(imag(Bi));
103+
res[n + i] = B0 + B1 * split + B2 * splitsplit;
104+
res[n + i] *= cur * step;
100105
}
101-
int64_t B0 = llround(imag(Ai));
102-
int64_t B1 = llround(imag(Ci));
103-
int64_t B2 = llround(imag(Bi));
104-
res[n + i] = B0 + B1 * split + B2 * splitsplit;
105-
res[n + i] *= cur * step;
106106
cur *= ifactor;
107107
});
108108
checkpoint("recover mod");
109109
}
110110
void mul_inplace(auto &&B, auto& res, size_t k) {
111-
mul(B.A, B.B, res, k, ifactor * B.ifactor);
111+
mul(B.A, B.B, res, k);
112112
}
113113
void mul(auto const& B, auto& res, size_t k) {
114-
mul(cvector(B.A), B.B, res, k, ifactor * B.ifactor);
114+
mul(cvector(B.A), B.B, res, k);
115115
}
116116
std::vector<base> operator *= (dft &B) {
117117
std::vector<base> res(2 * A.size());
@@ -132,6 +132,7 @@ namespace cp_algo::math::fft {
132132
template<modint_type base> base dft<base>::factor = 1;
133133
template<modint_type base> base dft<base>::ifactor = 1;
134134
template<modint_type base> bool dft<base>::init = false;
135+
template<modint_type base> int dft<base>::split = 1;
135136

136137
void mul_slow(auto &a, auto const& b, size_t k) {
137138
if(empty(a) || empty(b)) {

0 commit comments

Comments
 (0)