4
4
#include < iostream>
5
5
#include < cassert>
6
6
namespace cp_algo ::math {
7
- inline constexpr auto inv2 (auto x) {
8
- assert (x % 2 );
9
- std::make_unsigned_t <decltype (x)> y = 1 ;
10
- while (y * x != 1 ) {
11
- y *= 2 - x * y;
12
- }
13
- return y;
14
- }
15
7
16
8
template <typename modint, typename _Int>
17
9
struct modint_base {
@@ -23,97 +15,76 @@ namespace cp_algo::math {
23
15
static Int mod () {
24
16
return modint::mod ();
25
17
}
26
- static UInt imod () {
27
- return modint::imod ();
18
+ static Int remod () {
19
+ return modint::remod ();
28
20
}
29
- static UInt2 pw128 () {
30
- return modint::pw128 ();
31
- }
32
- static UInt m_reduce (UInt2 ab) {
33
- if (mod () % 2 == 0 ) [[unlikely]] {
34
- return UInt (ab % mod ());
35
- } else {
36
- UInt2 m = (UInt)ab * imod ();
37
- return UInt ((ab + m * mod ()) >> bits);
38
- }
39
- }
40
- static UInt m_reduce (Int2 ab) {
41
- return m_reduce (UInt2 (ab + UInt2 (ab < 0 ) * mod () * mod ()));
42
- }
43
- static UInt m_transform (UInt a) {
44
- if (mod () % 2 == 0 ) [[unlikely]] {
45
- return a;
46
- } else {
47
- return m_reduce (a * pw128 ());
48
- }
21
+ static UInt2 modmod () {
22
+ return UInt2 (mod ()) * mod ();
49
23
}
50
24
modint_base (): r(0 ) {}
51
- modint_base (Int2 rr): r(UInt(rr % mod())) {
52
- r = std::min (r, r + mod ());
53
- r = m_transform (r);
25
+ modint_base (Int2 rr) {
26
+ to_modint ().setr (UInt ((rr + modmod ()) % mod ()));
54
27
}
55
28
modint inv () const {
56
29
return bpow (to_modint (), mod () - 2 );
57
30
}
58
31
modint operator - () const {
59
32
modint neg;
60
- neg.r = std::min (-r, 2 * mod () - r);
33
+ neg.r = std::min (-r, remod () - r);
61
34
return neg;
62
35
}
63
36
modint& operator /= (const modint &t) {
64
37
return to_modint () *= t.inv ();
65
38
}
66
39
modint& operator *= (const modint &t) {
67
- r = m_reduce (( UInt2)r * t.r );
40
+ r = UInt ( UInt2 (r) * t.r % mod () );
68
41
return to_modint ();
69
42
}
70
43
modint& operator += (const modint &t) {
71
- r += t.r ; r = std::min (r, r - 2 * mod ());
44
+ r += t.r ; r = std::min (r, r - remod ());
72
45
return to_modint ();
73
46
}
74
47
modint& operator -= (const modint &t) {
75
- r -= t.r ; r = std::min (r, r + 2 * mod ());
48
+ r -= t.r ; r = std::min (r, r + remod ());
76
49
return to_modint ();
77
50
}
78
51
modint operator + (const modint &t) const {return modint (to_modint ()) += t;}
79
52
modint operator - (const modint &t) const {return modint (to_modint ()) -= t;}
80
53
modint operator * (const modint &t) const {return modint (to_modint ()) *= t;}
81
54
modint operator / (const modint &t) const {return modint (to_modint ()) /= t;}
82
55
// Why <=> doesn't work?..
83
- auto operator == (const modint_base &t) const {return getr () == t.getr ();}
84
- auto operator != (const modint_base &t) const {return getr () != t.getr ();}
85
- auto operator <= (const modint_base &t) const {return getr () <= t.getr ();}
86
- auto operator >= (const modint_base &t) const {return getr () >= t.getr ();}
87
- auto operator < (const modint_base &t) const {return getr () < t.getr ();}
88
- auto operator > (const modint_base &t) const {return getr () > t.getr ();}
56
+ auto operator == (const modint &t) const {return to_modint (). getr () == t.getr ();}
57
+ auto operator != (const modint &t) const {return to_modint (). getr () != t.getr ();}
58
+ auto operator <= (const modint &t) const {return to_modint (). getr () <= t.getr ();}
59
+ auto operator >= (const modint &t) const {return to_modint (). getr () >= t.getr ();}
60
+ auto operator < (const modint &t) const {return to_modint (). getr () < t.getr ();}
61
+ auto operator > (const modint &t) const {return to_modint (). getr () > t.getr ();}
89
62
Int rem () const {
90
- UInt R = getr ();
63
+ UInt R = to_modint (). getr ();
91
64
return 2 * R > (UInt)mod () ? R - mod () : R;
92
65
}
66
+ void setr (UInt rr) {
67
+ r = rr;
68
+ }
69
+ UInt getr () const {
70
+ return r;
71
+ }
93
72
94
- // Only use if you really know what you're doing!
95
- static UInt modmod () {return ( UInt) 8 * mod () * mod ( );};
73
+ // Only use these if you really know what you're doing!
74
+ static UInt modmod8 () {return UInt ( 8 * modmod () );}
96
75
void add_unsafe (UInt t) {r += t;}
97
- void pseudonormalize () {r = std::min (r, r - modmod ());}
76
+ void pseudonormalize () {r = std::min (r, r - modmod8 ());}
98
77
modint const & normalize () {
99
78
if (r >= (UInt)mod ()) {
100
79
r %= mod ();
101
80
}
102
81
return to_modint ();
103
82
}
104
- void setr (UInt rr) {r = m_transform (rr);}
105
- UInt getr () const {
106
- UInt res = m_reduce (UInt2 (r));
107
- return std::min (res, res - mod ());
108
- }
109
83
void setr_direct (UInt rr) {r = rr;}
110
84
UInt getr_direct () const {return r;}
111
- Int rem_direct () const {
112
- UInt R = std::min (r, r - mod ());
113
- return 2 * R > (UInt)mod () ? R - mod () : R;
114
- }
115
- private:
85
+ protected:
116
86
UInt r;
87
+ private:
117
88
modint& to_modint () {return static_cast <modint&>(*this );}
118
89
modint const & to_modint () const {return static_cast <modint const &>(*this );}
119
90
};
@@ -135,18 +106,53 @@ namespace cp_algo::math {
135
106
struct modint : modint_base<modint<m>, decltype (m)> {
136
107
using Base = modint_base<modint<m>, decltype (m)>;
137
108
using Base::Base;
138
- static constexpr Base::UInt im = m % 2 ? inv2(-m) : 0 ;
139
- static constexpr Base::UInt r2 = (typename Base::UInt2)(-1 ) % m + 1 ;
140
109
static constexpr Base::Int mod () {return m;}
141
- static constexpr Base::UInt imod () {return im ;}
142
- static constexpr Base::UInt2 pw128 () {return r2 ;}
110
+ static constexpr Base::UInt remod () {return m ;}
111
+ auto getr () const {return Base::r ;}
143
112
};
144
113
114
+ inline constexpr auto inv2 (auto x) {
115
+ assert (x % 2 );
116
+ std::make_unsigned_t <decltype (x)> y = 1 ;
117
+ while (y * x != 1 ) {
118
+ y *= 2 - x * y;
119
+ }
120
+ return y;
121
+ }
122
+
145
123
template <typename Int = int64_t >
146
124
struct dynamic_modint : modint_base<dynamic_modint<Int>, Int> {
147
125
using Base = modint_base<dynamic_modint<Int>, Int>;
148
126
using Base::Base;
127
+
128
+ static Base::UInt m_reduce (Base::UInt2 ab) {
129
+ if (mod () % 2 == 0 ) [[unlikely]] {
130
+ return typename Base::UInt (ab % mod ());
131
+ } else {
132
+ typename Base::UInt2 m = typename Base::UInt (ab) * imod ();
133
+ return typename Base::UInt ((ab + m * mod ()) >> Base::bits);
134
+ }
135
+ }
136
+ static Base::UInt m_transform (Base::UInt a) {
137
+ if (mod () % 2 == 0 ) [[unlikely]] {
138
+ return a;
139
+ } else {
140
+ return m_reduce (a * pw128 ());
141
+ }
142
+ }
143
+ dynamic_modint& operator *= (const dynamic_modint &t) {
144
+ Base::r = m_reduce (typename Base::UInt2 (Base::r) * t.r );
145
+ return *this ;
146
+ }
147
+ void setr (Base::UInt rr) {
148
+ Base::r = m_transform (rr);
149
+ }
150
+ Base::UInt getr () const {
151
+ typename Base::UInt res = m_reduce (Base::r);
152
+ return std::min (res, res - mod ());
153
+ }
149
154
static Int mod () {return m;}
155
+ static Int remod () {return 2 * m;}
150
156
static Base::UInt imod () {return im;}
151
157
static Base::UInt2 pw128 () {return r2;}
152
158
static void switch_mod (Int nm) {
0 commit comments