Skip to content

Commit d6de241

Browse files
committed
Make fft work over x^n-i
1 parent a2e17f8 commit d6de241

File tree

3 files changed

+49
-98
lines changed

3 files changed

+49
-98
lines changed

cp-algo/math/cvector.hpp

+30-69
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,6 @@ namespace cp_algo::math::fft {
2424
r.resize(n / flen);
2525
checkpoint("cvector create");
2626
}
27-
cvector(cvector const& t) {
28-
r.resize(t.r.size());
29-
for(size_t i = 0; i < r.size(); i++) {
30-
r[i] = {vftype(t.r[i].real()), vftype(t.r[i].imag())};
31-
}
32-
checkpoint("cvector copy");
33-
}
34-
cvector(cvector&& t) = delete;
3527

3628
vpoint& at(size_t k) {return r[k / flen];}
3729
vpoint at(size_t k) const {return r[k / flen];}
@@ -63,74 +55,53 @@ namespace cp_algo::math::fft {
6355
return eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1);
6456
}
6557
}
66-
static auto root(size_t n, size_t k) {
67-
if(n < pre_roots) {
68-
return roots[n + k];
69-
} else if (k % 2 == 0) {
70-
return root(n / 2, k / 2);
71-
} else {
72-
return polar(1., std::numbers::pi / (ftype)n * (ftype)k);
73-
}
74-
}
7558
static point eval_point(size_t n) {
7659
if(n % 2) {
77-
return eval_point(n - 1) * point(0, 1);
78-
} else if(n / 2 < pre_evals) {
79-
return evalp[n / 2];
60+
return -eval_point(n - 1);
61+
} else if(n % 4) {
62+
return eval_point(n - 2) * point(0, 1);
63+
} else if(n / 4 < pre_evals) {
64+
return evalp[n / 4];
8065
} else {
81-
return root(2 * std::bit_floor(n), eval_arg(n));
66+
return polar(1., std::numbers::pi / (ftype)std::bit_floor(n) * (ftype)eval_arg(n));
8267
}
8368
}
84-
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
85-
point cur = {1, 0};
86-
point arg = root(n, 1);
87-
for(size_t i = 0; i < m; i++) {
88-
callback(i, cur);
89-
if(i % 64 == 63) {
90-
cur = root(n / 64, i / 64 + 1);
91-
} else {
92-
cur *= arg;
93-
}
94-
}
69+
static point root(size_t n) {
70+
return polar(1., 2. * std::numbers::pi / (ftype)n);
9571
}
96-
template<int step = 1>
72+
template<int step>
9773
static void exec_on_evals(size_t n, auto &&callback) {
74+
point factor = root(4 * step * n);
9875
for(size_t i = 0; i < n; i++) {
99-
callback(i, eval_point(step * i));
100-
}
101-
}
102-
static auto dot_block(size_t k, cvector const& A, cvector const& B) {
103-
auto rt = eval_point(k / flen / 2);
104-
if(k / flen % 2) {
105-
rt = -rt;
76+
callback(i, factor * eval_point(step * i));
10677
}
107-
auto [Ax, Ay] = A.at(k);
108-
auto Bv = B.at(k);
109-
vpoint res = vz;
110-
for (size_t i = 0; i < flen; i++) {
111-
res += vpoint(vz + Ax[i], vz + Ay[i]) * Bv;
112-
real(Bv) = __builtin_shufflevector(real(Bv), real(Bv), 3, 0, 1, 2);
113-
imag(Bv) = __builtin_shufflevector(imag(Bv), imag(Bv), 3, 0, 1, 2);
114-
auto x = real(Bv)[0], y = imag(Bv)[0];
115-
real(Bv)[0] = x * real(rt) - y * imag(rt);
116-
imag(Bv)[0] = x * imag(rt) + y * real(rt);
117-
}
118-
return res;
11978
}
12079

12180
void dot(cvector const& t) {
12281
size_t n = this->size();
123-
for(size_t k = 0; k < n; k += flen) {
124-
set(k, dot_block(k, *this, t));
125-
}
82+
exec_on_evals<1>(n / flen, [&](size_t k, point rt) {
83+
k *= flen;
84+
auto [Ax, Ay] = at(k);
85+
auto Bv = t.at(k);
86+
vpoint res = vz;
87+
for (size_t i = 0; i < flen; i++) {
88+
res += vpoint(vz + Ax[i], vz + Ay[i]) * Bv;
89+
real(Bv) = __builtin_shufflevector(real(Bv), real(Bv), 3, 0, 1, 2);
90+
imag(Bv) = __builtin_shufflevector(imag(Bv), imag(Bv), 3, 0, 1, 2);
91+
auto x = real(Bv)[0], y = imag(Bv)[0];
92+
real(Bv)[0] = x * real(rt) - y * imag(rt);
93+
imag(Bv)[0] = x * imag(rt) + y * real(rt);
94+
}
95+
set(k, res);
96+
});
12697
checkpoint("dot");
12798
}
12899

129100
void ifft() {
130101
size_t n = size();
131102
for(size_t i = flen; i <= n / 2; i *= 2) {
132103
if (4 * i <= n) { // radix-4
133-
exec_on_evals<2>(n / (4 * i), [&](size_t k, point rt) {
104+
exec_on_evals<4>(n / (4 * i), [&](size_t k, point rt) {
134105
k *= 4 * i;
135106
vpoint v1 = {vz + real(rt), vz - imag(rt)};
136107
vpoint v2 = v1 * v1;
@@ -148,7 +119,7 @@ namespace cp_algo::math::fft {
148119
});
149120
i *= 2;
150121
} else { // radix-2 fallback
151-
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
122+
exec_on_evals<2>(n / (2 * i), [&](size_t k, point rt) {
152123
k *= 2 * i;
153124
vpoint cvrt = {vz + real(rt), vz - imag(rt)};
154125
for(size_t j = k; j < k + i; j += flen) {
@@ -169,7 +140,7 @@ namespace cp_algo::math::fft {
169140
for(size_t i = n / 2; i >= flen; i /= 2) {
170141
if (i / 2 >= flen) { // radix-4
171142
i /= 2;
172-
exec_on_evals<2>(n / (4 * i), [&](size_t k, point rt) {
143+
exec_on_evals<4>(n / (4 * i), [&](size_t k, point rt) {
173144
k *= 4 * i;
174145
vpoint v1 = {vz + real(rt), vz + imag(rt)};
175146
vpoint v2 = v1 * v1;
@@ -186,7 +157,7 @@ namespace cp_algo::math::fft {
186157
}
187158
});
188159
} else { // radix-2 fallback
189-
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
160+
exec_on_evals<2>(n / (2 * i), [&](size_t k, point rt) {
190161
k *= 2 * i;
191162
vpoint vrt = {vz + real(rt), vz + imag(rt)};
192163
for(size_t j = k; j < k + i; j += flen) {
@@ -199,17 +170,7 @@ namespace cp_algo::math::fft {
199170
}
200171
checkpoint("fft");
201172
}
202-
static constexpr size_t pre_roots = 1 << 14;
203173
static constexpr size_t pre_evals = 1 << 16;
204-
static constexpr std::array<point, pre_roots> roots = []() {
205-
std::array<point, pre_roots> res = {};
206-
for(size_t n = 1; n < res.size(); n *= 2) {
207-
for(size_t k = 0; k < n; k++) {
208-
res[n + k] = polar(1., std::numbers::pi / ftype(n) * ftype(k));
209-
}
210-
}
211-
return res;
212-
}();
213174
static constexpr std::array<size_t, pre_evals> eval_args = []() {
214175
std::array<size_t, pre_evals> res = {};
215176
for(size_t i = 1; i < pre_evals; i++) {

cp-algo/math/fft.hpp

+19-22
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace cp_algo::math::fft {
2424
}
2525
base cur = factor;
2626
base step = bpow(factor, n);
27-
cvector::exec_on_roots(2 * n, std::min(n, size(a)), [&](size_t i, auto rt) {
27+
for(size_t i = 0; i < std::min(n, size(a)); i++) {
2828
auto splt = [&](size_t i, auto mul) {
2929
auto ai = i < size(a) ? (a[i] * mul).rem_direct() : 0;
3030
auto rem = ai % split;
@@ -33,10 +33,10 @@ namespace cp_algo::math::fft {
3333
};
3434
auto [rai, qai] = splt(i, cur);
3535
auto [rani, qani] = splt(n + i, cur * step);
36-
A.set(i, point(rai, rani) * rt);
37-
B.set(i, point(qai, qani) * rt);
36+
A.set(i, point(rai, rani));
37+
B.set(i, point(qai, qani));
3838
cur *= factor;
39-
});
39+
}
4040
checkpoint("dft init");
4141
if(n) {
4242
A.fft();
@@ -51,11 +51,9 @@ namespace cp_algo::math::fft {
5151
res = {};
5252
return;
5353
}
54-
for(size_t k = 0; k < n; k += flen) {
55-
auto rt = cvector::eval_point(k / flen / 2);
56-
if(k / flen % 2) {
57-
rt = -rt;
58-
}
54+
55+
cvector::exec_on_evals<1>(n / flen, [&](size_t k, point rt) {
56+
k *= flen;
5957
auto [Ax, Ay] = A.at(k);
6058
auto [Bx, By] = B.at(k);
6159
vpoint AC, AD, BC, BD;
@@ -79,33 +77,32 @@ namespace cp_algo::math::fft {
7977
A.at(k) = AC;
8078
C.at(k) = AD + BC;
8179
B.at(k) = BD;
82-
}
80+
});
8381
checkpoint("dot");
8482
A.ifft();
8583
B.ifft();
8684
C.ifft();
8785
auto splitsplit = (base(split) * split).rem();
8886
base cur = ifactor * ifactor;
8987
base step = bpow(ifactor, n);
90-
cvector::exec_on_roots(2 * n, std::min(n, k), [&](size_t i, point rt) {
91-
rt = conj(rt);
92-
auto Ai = A.get(i) * rt;
93-
auto Bi = B.get(i) * rt;
94-
auto Ci = C.get(i) * rt;
95-
Int2 A0 = llround(real(Ai));
96-
Int2 A1 = llround(real(Ci));
97-
Int2 A2 = llround(real(Bi));
88+
for(size_t i = 0; i < std::min(n, k); i++) {
89+
auto [Ax, Ay] = A.get(i);
90+
auto [Bx, By] = B.get(i);
91+
auto [Cx, Cy] = C.get(i);
92+
Int2 A0 = llround(Ax);
93+
Int2 A1 = llround(Cx);
94+
Int2 A2 = llround(Bx);
9895
res[i].setr_direct(base::m_reduce(A0 + A1 * split + A2 * splitsplit));
9996
res[i] *= cur;
10097
if(n + i < k) {
101-
Int2 B0 = llround(imag(Ai));
102-
Int2 B1 = llround(imag(Ci));
103-
Int2 B2 = llround(imag(Bi));
98+
Int2 B0 = llround(Ay);
99+
Int2 B1 = llround(Cy);
100+
Int2 B2 = llround(By);
104101
res[n + i].setr_direct(base::m_reduce(B0 + B1 * split + B2 * splitsplit));
105102
res[n + i] *= cur * step;
106103
}
107104
cur *= ifactor;
108-
});
105+
}
109106
checkpoint("recover mod");
110107
}
111108
void mul_inplace(auto &&B, auto& res, size_t k) {

verify/poly/wildcard.test.cpp

-7
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,6 @@ auto round(vftype a) {
3131
return __builtin_convertvector(__builtin_convertvector(a < 0 ? a - 0.5 : a + 0.5, v4di), vftype);
3232
}
3333

34-
void print(auto r) {
35-
for(int z = 0; z < 4; z++) {
36-
cout << r[z] << ' ';
37-
}
38-
cout << endl;
39-
}
40-
4134
auto is_integer(auto a) {
4235
static const double eps = 1e-8;
4336
return abs(imag(a)) < eps

0 commit comments

Comments
 (0)