Skip to content

Commit 20dbb1f

Browse files
committed
fix dynamic_modint convolution + simplify root calculation
1 parent f273d92 commit 20dbb1f

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

cp-algo/math/cvector.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,14 @@ namespace cp_algo::math::fft {
6969
size_t size() const {
7070
return flen * r.size();
7171
}
72-
static size_t eval_arg(size_t n) {
72+
static constexpr size_t eval_arg(size_t n) {
7373
if(n < pre_evals) {
7474
return eval_args[n];
7575
} else {
7676
return eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1);
7777
}
7878
}
79-
static point eval_point(size_t n) {
79+
static constexpr point eval_point(size_t n) {
8080
if(n % 2) {
8181
return -eval_point(n - 1);
8282
} else if(n % 4) {
@@ -87,8 +87,8 @@ namespace cp_algo::math::fft {
8787
return polar(1., std::numbers::pi / (ftype)std::bit_floor(n) * (ftype)eval_arg(n));
8888
}
8989
}
90-
static point root(size_t n) {
91-
return polar(1., 2. * std::numbers::pi / (ftype)n);
90+
static constexpr point root(size_t n) {
91+
return eval_point(n / 2);
9292
}
9393
template<int step>
9494
static void exec_on_evals(size_t n, auto &&callback) {
@@ -97,6 +97,11 @@ namespace cp_algo::math::fft {
9797
callback(i, factor * eval_point(step * i));
9898
}
9999
}
100+
template<int step>
101+
static void exec_on_eval(size_t n, size_t k, auto &&callback) {
102+
point factor = root(4 * step * n);
103+
callback(factor * eval_point(step * k));
104+
}
100105

101106
void dot(cvector const& t) {
102107
size_t n = this->size();

cp-algo/math/fft.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ namespace cp_algo::math::fft {
101101
auto [Cx, Cy] = C.at(i);
102102
auto set_i = [&](size_t i, auto A, auto B, auto C, auto mul) {
103103
auto A0 = lround(A), A1 = lround(C), A2 = lround(B);
104-
auto Ai = A0 + A1 * split + A2 * splitsplit + base::modmod();
104+
auto Ai = A0 + A1 * split + A2 * splitsplit + uint64_t(base::modmod());
105105
auto Au = montgomery_reduce(u64x4(Ai), mod, imod);
106106
Au = montgomery_mul(Au, mul, mod, imod);
107107
Au = Au >= base::mod() ? Au - base::mod() : Au;

0 commit comments

Comments
 (0)