@@ -24,14 +24,6 @@ namespace cp_algo::math::fft {
24
24
r.resize (n / flen);
25
25
checkpoint (" cvector create" );
26
26
}
27
- cvector (cvector const & t) {
28
- r.resize (t.r .size ());
29
- for (size_t i = 0 ; i < r.size (); i++) {
30
- r[i] = {vftype (t.r [i].real ()), vftype (t.r [i].imag ())};
31
- }
32
- checkpoint (" cvector copy" );
33
- }
34
- cvector (cvector&& t) = delete ;
35
27
36
28
vpoint& at (size_t k) {return r[k / flen];}
37
29
vpoint at (size_t k) const {return r[k / flen];}
@@ -63,74 +55,53 @@ namespace cp_algo::math::fft {
63
55
return eval_arg (n / 2 ) | (n & 1 ) << (std::bit_width (n) - 1 );
64
56
}
65
57
}
66
- static auto root (size_t n, size_t k) {
67
- if (n < pre_roots) {
68
- return roots[n + k];
69
- } else if (k % 2 == 0 ) {
70
- return root (n / 2 , k / 2 );
71
- } else {
72
- return polar (1 ., std::numbers::pi / (ftype)n * (ftype)k);
73
- }
74
- }
75
58
static point eval_point (size_t n) {
76
59
if (n % 2 ) {
77
- return eval_point (n - 1 ) * point (0 , 1 );
78
- } else if (n / 2 < pre_evals) {
79
- return evalp[n / 2 ];
60
+ return -eval_point (n - 1 );
61
+ } else if (n % 4 ) {
62
+ return eval_point (n - 2 ) * point (0 , 1 );
63
+ } else if (n / 4 < pre_evals) {
64
+ return evalp[n / 4 ];
80
65
} else {
81
- return root ( 2 * std::bit_floor (n), eval_arg (n));
66
+ return polar ( 1 ., std::numbers:: pi / (ftype) std::bit_floor (n) * (ftype) eval_arg (n));
82
67
}
83
68
}
84
- static void exec_on_roots (size_t n, size_t m, auto &&callback) {
85
- point cur = {1 , 0 };
86
- point arg = root (n, 1 );
87
- for (size_t i = 0 ; i < m; i++) {
88
- callback (i, cur);
89
- if (i % 64 == 63 ) {
90
- cur = root (n / 64 , i / 64 + 1 );
91
- } else {
92
- cur *= arg;
93
- }
94
- }
69
+ static point root (size_t n) {
70
+ return polar (1 ., 2 . * std::numbers::pi / (ftype)n);
95
71
}
96
- template <int step = 1 >
72
+ template <int step>
97
73
static void exec_on_evals (size_t n, auto &&callback) {
74
+ point factor = root (4 * step * n);
98
75
for (size_t i = 0 ; i < n; i++) {
99
- callback (i, eval_point (step * i));
100
- }
101
- }
102
- static auto dot_block (size_t k, cvector const & A, cvector const & B) {
103
- auto rt = eval_point (k / flen / 2 );
104
- if (k / flen % 2 ) {
105
- rt = -rt;
76
+ callback (i, factor * eval_point (step * i));
106
77
}
107
- auto [Ax, Ay] = A.at (k);
108
- auto Bv = B.at (k);
109
- vpoint res = vz;
110
- for (size_t i = 0 ; i < flen; i++) {
111
- res += vpoint (vz + Ax[i], vz + Ay[i]) * Bv;
112
- real (Bv) = __builtin_shufflevector (real (Bv), real (Bv), 3 , 0 , 1 , 2 );
113
- imag (Bv) = __builtin_shufflevector (imag (Bv), imag (Bv), 3 , 0 , 1 , 2 );
114
- auto x = real (Bv)[0 ], y = imag (Bv)[0 ];
115
- real (Bv)[0 ] = x * real (rt) - y * imag (rt);
116
- imag (Bv)[0 ] = x * imag (rt) + y * real (rt);
117
- }
118
- return res;
119
78
}
120
79
121
80
void dot (cvector const & t) {
122
81
size_t n = this ->size ();
123
- for (size_t k = 0 ; k < n; k += flen) {
124
- set (k, dot_block (k, *this , t));
125
- }
82
+ exec_on_evals<1 >(n / flen, [&](size_t k, point rt) {
83
+ k *= flen;
84
+ auto [Ax, Ay] = at (k);
85
+ auto Bv = t.at (k);
86
+ vpoint res = vz;
87
+ for (size_t i = 0 ; i < flen; i++) {
88
+ res += vpoint (vz + Ax[i], vz + Ay[i]) * Bv;
89
+ real (Bv) = __builtin_shufflevector (real (Bv), real (Bv), 3 , 0 , 1 , 2 );
90
+ imag (Bv) = __builtin_shufflevector (imag (Bv), imag (Bv), 3 , 0 , 1 , 2 );
91
+ auto x = real (Bv)[0 ], y = imag (Bv)[0 ];
92
+ real (Bv)[0 ] = x * real (rt) - y * imag (rt);
93
+ imag (Bv)[0 ] = x * imag (rt) + y * real (rt);
94
+ }
95
+ set (k, res);
96
+ });
126
97
checkpoint (" dot" );
127
98
}
128
99
129
100
void ifft () {
130
101
size_t n = size ();
131
102
for (size_t i = flen; i <= n / 2 ; i *= 2 ) {
132
103
if (4 * i <= n) { // radix-4
133
- exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
104
+ exec_on_evals<4 >(n / (4 * i), [&](size_t k, point rt) {
134
105
k *= 4 * i;
135
106
vpoint v1 = {vz + real (rt), vz - imag (rt)};
136
107
vpoint v2 = v1 * v1;
@@ -148,7 +119,7 @@ namespace cp_algo::math::fft {
148
119
});
149
120
i *= 2 ;
150
121
} else { // radix-2 fallback
151
- exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
122
+ exec_on_evals< 2 > (n / (2 * i), [&](size_t k, point rt) {
152
123
k *= 2 * i;
153
124
vpoint cvrt = {vz + real (rt), vz - imag (rt)};
154
125
for (size_t j = k; j < k + i; j += flen) {
@@ -169,7 +140,7 @@ namespace cp_algo::math::fft {
169
140
for (size_t i = n / 2 ; i >= flen; i /= 2 ) {
170
141
if (i / 2 >= flen) { // radix-4
171
142
i /= 2 ;
172
- exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
143
+ exec_on_evals<4 >(n / (4 * i), [&](size_t k, point rt) {
173
144
k *= 4 * i;
174
145
vpoint v1 = {vz + real (rt), vz + imag (rt)};
175
146
vpoint v2 = v1 * v1;
@@ -186,7 +157,7 @@ namespace cp_algo::math::fft {
186
157
}
187
158
});
188
159
} else { // radix-2 fallback
189
- exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
160
+ exec_on_evals< 2 > (n / (2 * i), [&](size_t k, point rt) {
190
161
k *= 2 * i;
191
162
vpoint vrt = {vz + real (rt), vz + imag (rt)};
192
163
for (size_t j = k; j < k + i; j += flen) {
@@ -199,17 +170,7 @@ namespace cp_algo::math::fft {
199
170
}
200
171
checkpoint (" fft" );
201
172
}
202
- static constexpr size_t pre_roots = 1 << 14 ;
203
173
static constexpr size_t pre_evals = 1 << 16 ;
204
- static constexpr std::array<point, pre_roots> roots = []() {
205
- std::array<point, pre_roots> res = {};
206
- for (size_t n = 1 ; n < res.size (); n *= 2 ) {
207
- for (size_t k = 0 ; k < n; k++) {
208
- res[n + k] = polar (1 ., std::numbers::pi / ftype (n) * ftype (k));
209
- }
210
- }
211
- return res;
212
- }();
213
174
static constexpr std::array<size_t , pre_evals> eval_args = []() {
214
175
std::array<size_t , pre_evals> res = {};
215
176
for (size_t i = 1 ; i < pre_evals; i++) {
0 commit comments