Skip to content

Commit 7045e1d

Browse files
committed
Improve fft-mod init
1 parent 23aabb1 commit 7045e1d

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

cp-algo/math/fft.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ namespace cp_algo::math::fft {
2525
base step = bpow(factor, n);
2626
cvector::exec_on_roots(2 * n, std::min(n, size(a)), [&](size_t i, auto rt) {
2727
auto splt = [&](size_t i, auto mul) {
28-
auto ai = ftype(i < size(a) ? (a[i] * mul).rem() : 0);
29-
auto rem = std::remainder(ai, split);
28+
auto ai = i < size(a) ? (a[i] * mul).rem_direct() : 0;
29+
auto rem = ai % split;
3030
auto quo = (ai - rem) / split;
3131
return std::pair{rem, quo};
3232
};
@@ -94,13 +94,13 @@ namespace cp_algo::math::fft {
9494
int64_t A0 = llround(real(Ai));
9595
int64_t A1 = llround(real(Ci));
9696
int64_t A2 = llround(real(Bi));
97-
res[i] = A0 + A1 * split + A2 * splitsplit;
97+
res[i].setr_direct(base::m_reduce(A0 + A1 * split + A2 * splitsplit));
9898
res[i] *= cur;
9999
if(n + i < k) {
100100
int64_t B0 = llround(imag(Ai));
101101
int64_t B1 = llround(imag(Ci));
102102
int64_t B2 = llround(imag(Bi));
103-
res[n + i] = B0 + B1 * split + B2 * splitsplit;
103+
res[n + i].setr_direct(base::m_reduce(B0 + B1 * split + B2 * splitsplit));
104104
res[n + i] *= cur * step;
105105
}
106106
cur *= ifactor;

cp-algo/number_theory/modint.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ namespace cp_algo::math {
3737
return UInt((ab + m * mod()) >> bits);
3838
}
3939
}
40+
static UInt m_reduce(Int2 ab) {
41+
return m_reduce(UInt2(ab + UInt2(ab < 0) * mod() * mod()));
42+
}
4043
static UInt m_transform(UInt a) {
4144
if(mod() % 2 == 0) [[unlikely]] {
4245
return a;
@@ -89,7 +92,7 @@ namespace cp_algo::math {
8992
}
9093

9194
// Only use if you really know what you're doing!
92-
UInt modmod() const {return (UInt)8 * mod() * mod();};
95+
static UInt modmod() {return (UInt)8 * mod() * mod();};
9396
void add_unsafe(UInt t) {r += t;}
9497
void pseudonormalize() {r = std::min(r, r - modmod());}
9598
modint const& normalize() {
@@ -100,11 +103,15 @@ namespace cp_algo::math {
100103
}
101104
void setr(UInt rr) {r = m_transform(rr);}
102105
UInt getr() const {
103-
UInt res = m_reduce(r);
106+
UInt res = m_reduce(UInt2(r));
104107
return std::min(res, res - mod());
105108
}
106109
void setr_direct(UInt rr) {r = rr;}
107110
UInt getr_direct() const {return r;}
111+
Int rem_direct() const {
112+
UInt R = std::min(r, r - mod());
113+
return 2 * R > (UInt)mod() ? R - mod() : R;
114+
}
108115
private:
109116
UInt r;
110117
modint& to_modint() {return static_cast<modint&>(*this);}

0 commit comments

Comments
 (0)