Skip to content

Commit b61487e

Browse files
committed
Move Montgomery to dynamic_modint
1 parent 4021d95 commit b61487e

File tree

2 files changed

+70
-64
lines changed

2 files changed

+70
-64
lines changed

cp-algo/math/fft.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace cp_algo::math::fft {
2626
base step = bpow(factor, n);
2727
for(size_t i = 0; i < std::min(n, size(a)); i++) {
2828
auto splt = [&](size_t i, auto mul) {
29-
auto ai = i < size(a) ? (a[i] * mul).rem_direct() : 0;
29+
auto ai = i < size(a) ? (a[i] * mul).rem() : 0;
3030
auto rem = ai % split;
3131
auto quo = ai / split;
3232
return std::pair{(ftype)rem, (ftype)quo};
@@ -95,13 +95,13 @@ namespace cp_algo::math::fft {
9595
Int2 A0 = llround(Ax);
9696
Int2 A1 = llround(Cx);
9797
Int2 A2 = llround(Bx);
98-
res[i].setr_direct(base::m_reduce(A0 + A1 * split + A2 * splitsplit));
98+
res[i] = A0 + A1 * split + A2 * splitsplit;
9999
res[i] *= cur;
100100
if(n + i < k) {
101101
Int2 B0 = llround(Ay);
102102
Int2 B1 = llround(Cy);
103103
Int2 B2 = llround(By);
104-
res[n + i].setr_direct(base::m_reduce(B0 + B1 * split + B2 * splitsplit));
104+
res[n + i] = B0 + B1 * split + B2 * splitsplit;
105105
res[n + i] *= cur * step;
106106
}
107107
cur *= ifactor;

cp-algo/number_theory/modint.hpp

Lines changed: 67 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,6 @@
44
#include <iostream>
55
#include <cassert>
66
namespace cp_algo::math {
7-
inline constexpr auto inv2(auto x) {
8-
assert(x % 2);
9-
std::make_unsigned_t<decltype(x)> y = 1;
10-
while(y * x != 1) {
11-
y *= 2 - x * y;
12-
}
13-
return y;
14-
}
157

168
template<typename modint, typename _Int>
179
struct modint_base {
@@ -23,97 +15,76 @@ namespace cp_algo::math {
2315
static Int mod() {
2416
return modint::mod();
2517
}
26-
static UInt imod() {
27-
return modint::imod();
18+
static Int remod() {
19+
return modint::remod();
2820
}
29-
static UInt2 pw128() {
30-
return modint::pw128();
31-
}
32-
static UInt m_reduce(UInt2 ab) {
33-
if(mod() % 2 == 0) [[unlikely]] {
34-
return UInt(ab % mod());
35-
} else {
36-
UInt2 m = (UInt)ab * imod();
37-
return UInt((ab + m * mod()) >> bits);
38-
}
39-
}
40-
static UInt m_reduce(Int2 ab) {
41-
return m_reduce(UInt2(ab + UInt2(ab < 0) * mod() * mod()));
42-
}
43-
static UInt m_transform(UInt a) {
44-
if(mod() % 2 == 0) [[unlikely]] {
45-
return a;
46-
} else {
47-
return m_reduce(a * pw128());
48-
}
21+
static UInt2 modmod() {
22+
return UInt2(mod()) * mod();
4923
}
5024
modint_base(): r(0) {}
51-
modint_base(Int2 rr): r(UInt(rr % mod())) {
52-
r = std::min(r, r + mod());
53-
r = m_transform(r);
25+
modint_base(Int2 rr) {
26+
to_modint().setr(UInt((rr + modmod()) % mod()));
5427
}
5528
modint inv() const {
5629
return bpow(to_modint(), mod() - 2);
5730
}
5831
modint operator - () const {
5932
modint neg;
60-
neg.r = std::min(-r, 2 * mod() - r);
33+
neg.r = std::min(-r, remod() - r);
6134
return neg;
6235
}
6336
modint& operator /= (const modint &t) {
6437
return to_modint() *= t.inv();
6538
}
6639
modint& operator *= (const modint &t) {
67-
r = m_reduce((UInt2)r * t.r);
40+
r = UInt(UInt2(r) * t.r % mod());
6841
return to_modint();
6942
}
7043
modint& operator += (const modint &t) {
71-
r += t.r; r = std::min(r, r - 2 * mod());
44+
r += t.r; r = std::min(r, r - remod());
7245
return to_modint();
7346
}
7447
modint& operator -= (const modint &t) {
75-
r -= t.r; r = std::min(r, r + 2 * mod());
48+
r -= t.r; r = std::min(r, r + remod());
7649
return to_modint();
7750
}
7851
modint operator + (const modint &t) const {return modint(to_modint()) += t;}
7952
modint operator - (const modint &t) const {return modint(to_modint()) -= t;}
8053
modint operator * (const modint &t) const {return modint(to_modint()) *= t;}
8154
modint operator / (const modint &t) const {return modint(to_modint()) /= t;}
8255
// Why <=> doesn't work?..
83-
auto operator == (const modint_base &t) const {return getr() == t.getr();}
84-
auto operator != (const modint_base &t) const {return getr() != t.getr();}
85-
auto operator <= (const modint_base &t) const {return getr() <= t.getr();}
86-
auto operator >= (const modint_base &t) const {return getr() >= t.getr();}
87-
auto operator < (const modint_base &t) const {return getr() < t.getr();}
88-
auto operator > (const modint_base &t) const {return getr() > t.getr();}
56+
auto operator == (const modint &t) const {return to_modint().getr() == t.getr();}
57+
auto operator != (const modint &t) const {return to_modint().getr() != t.getr();}
58+
auto operator <= (const modint &t) const {return to_modint().getr() <= t.getr();}
59+
auto operator >= (const modint &t) const {return to_modint().getr() >= t.getr();}
60+
auto operator < (const modint &t) const {return to_modint().getr() < t.getr();}
61+
auto operator > (const modint &t) const {return to_modint().getr() > t.getr();}
8962
Int rem() const {
90-
UInt R = getr();
63+
UInt R = to_modint().getr();
9164
return 2 * R > (UInt)mod() ? R - mod() : R;
9265
}
66+
void setr(UInt rr) {
67+
r = rr;
68+
}
69+
UInt getr() const {
70+
return r;
71+
}
9372

94-
// Only use if you really know what you're doing!
95-
static UInt modmod() {return (UInt)8 * mod() * mod();};
73+
// Only use these if you really know what you're doing!
74+
static UInt modmod8() {return UInt(8 * modmod());}
9675
void add_unsafe(UInt t) {r += t;}
97-
void pseudonormalize() {r = std::min(r, r - modmod());}
76+
void pseudonormalize() {r = std::min(r, r - modmod8());}
9877
modint const& normalize() {
9978
if(r >= (UInt)mod()) {
10079
r %= mod();
10180
}
10281
return to_modint();
10382
}
104-
void setr(UInt rr) {r = m_transform(rr);}
105-
UInt getr() const {
106-
UInt res = m_reduce(UInt2(r));
107-
return std::min(res, res - mod());
108-
}
10983
void setr_direct(UInt rr) {r = rr;}
11084
UInt getr_direct() const {return r;}
111-
Int rem_direct() const {
112-
UInt R = std::min(r, r - mod());
113-
return 2 * R > (UInt)mod() ? R - mod() : R;
114-
}
115-
private:
85+
protected:
11686
UInt r;
87+
private:
11788
modint& to_modint() {return static_cast<modint&>(*this);}
11889
modint const& to_modint() const {return static_cast<modint const&>(*this);}
11990
};
@@ -135,18 +106,53 @@ namespace cp_algo::math {
135106
struct modint: modint_base<modint<m>, decltype(m)> {
136107
using Base = modint_base<modint<m>, decltype(m)>;
137108
using Base::Base;
138-
static constexpr Base::UInt im = m % 2 ? inv2(-m) : 0;
139-
static constexpr Base::UInt r2 = (typename Base::UInt2)(-1) % m + 1;
140109
static constexpr Base::Int mod() {return m;}
141-
static constexpr Base::UInt imod() {return im;}
142-
static constexpr Base::UInt2 pw128() {return r2;}
110+
static constexpr Base::UInt remod() {return m;}
111+
auto getr() const {return Base::r;}
143112
};
144113

114+
inline constexpr auto inv2(auto x) {
115+
assert(x % 2);
116+
std::make_unsigned_t<decltype(x)> y = 1;
117+
while(y * x != 1) {
118+
y *= 2 - x * y;
119+
}
120+
return y;
121+
}
122+
145123
template<typename Int = int64_t>
146124
struct dynamic_modint: modint_base<dynamic_modint<Int>, Int> {
147125
using Base = modint_base<dynamic_modint<Int>, Int>;
148126
using Base::Base;
127+
128+
static Base::UInt m_reduce(Base::UInt2 ab) {
129+
if(mod() % 2 == 0) [[unlikely]] {
130+
return typename Base::UInt(ab % mod());
131+
} else {
132+
typename Base::UInt2 m = typename Base::UInt(ab) * imod();
133+
return typename Base::UInt((ab + m * mod()) >> Base::bits);
134+
}
135+
}
136+
static Base::UInt m_transform(Base::UInt a) {
137+
if(mod() % 2 == 0) [[unlikely]] {
138+
return a;
139+
} else {
140+
return m_reduce(a * pw128());
141+
}
142+
}
143+
dynamic_modint& operator *= (const dynamic_modint &t) {
144+
Base::r = m_reduce(typename Base::UInt2(Base::r) * t.r);
145+
return *this;
146+
}
147+
void setr(Base::UInt rr) {
148+
Base::r = m_transform(rr);
149+
}
150+
Base::UInt getr() const {
151+
typename Base::UInt res = m_reduce(Base::r);
152+
return std::min(res, res - mod());
153+
}
149154
static Int mod() {return m;}
155+
static Int remod() {return 2 * m;}
150156
static Base::UInt imod() {return im;}
151157
static Base::UInt2 pw128() {return r2;}
152158
static void switch_mod(Int nm) {

0 commit comments

Comments
 (0)