Skip to content

Commit a6e235b

Browse files
committed
factor out rotate_right
1 parent 88a38e3 commit a6e235b

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

cp-algo/math/cvector.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace stdx = std::experimental;
1010
namespace cp_algo::math::fft {
1111
static constexpr size_t flen = 4;
1212
using ftype = double;
13-
using vftype = simd<ftype, flen>;
13+
using vftype = dx4;
1414
using point = complex<ftype>;
1515
using vpoint = complex<vftype>;
1616
static constexpr vftype vz = {};
@@ -91,8 +91,8 @@ namespace cp_algo::math::fft {
9191
vpoint res = vz;
9292
for (size_t i = 0; i < flen; i++) {
9393
res += vpoint(vz + Ax[i], vz + Ay[i]) * Bv;
94-
real(Bv) = __builtin_shufflevector(real(Bv), real(Bv), 3, 0, 1, 2);
95-
imag(Bv) = __builtin_shufflevector(imag(Bv), imag(Bv), 3, 0, 1, 2);
94+
real(Bv) = rotate_right(real(Bv));
95+
imag(Bv) = rotate_right(imag(Bv));
9696
auto x = real(Bv)[0], y = imag(Bv)[0];
9797
real(Bv)[0] = x * real(rt) - y * imag(rt);
9898
imag(Bv)[0] = x * imag(rt) + y * real(rt);

cp-algo/math/fft.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ namespace cp_algo::math::fft {
6363
vpoint Av = {vz + Ax[i], vz + Ay[i]}, Bv = {vz + Bx[i], vz + By[i]};
6464
AC += Av * Cv; AD += Av * Dv;
6565
BC += Bv * Cv; BD += Bv * Dv;
66-
real(Cv) = __builtin_shufflevector(real(Cv), real(Cv), 3, 0, 1, 2);
67-
imag(Cv) = __builtin_shufflevector(imag(Cv), imag(Cv), 3, 0, 1, 2);
68-
real(Dv) = __builtin_shufflevector(real(Dv), real(Dv), 3, 0, 1, 2);
69-
imag(Dv) = __builtin_shufflevector(imag(Dv), imag(Dv), 3, 0, 1, 2);
66+
real(Cv) = rotate_right(real(Cv));
67+
imag(Cv) = rotate_right(imag(Cv));
68+
real(Dv) = rotate_right(real(Dv));
69+
imag(Dv) = rotate_right(imag(Dv));
7070
auto cx = real(Cv)[0], cy = imag(Cv)[0];
7171
auto dx = real(Dv)[0], dy = imag(Dv)[0];
7272
real(Cv)[0] = cx * real(rt) - cy * imag(rt);

cp-algo/util/simd.hpp

+13-7
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,25 @@ namespace cp_algo {
1010
using u64x4 = simd<uint64_t, 4>;
1111
using u32x8 = simd<uint32_t, 8>;
1212
using u32x4 = simd<uint32_t, 4>;
13+
using dx4 = simd<double, 4>;
1314

14-
template<typename Simd>
15-
Simd abs(Simd a) {
15+
dx4 abs(dx4 a) {
1616
#ifdef __AVX2__
17-
return _mm256_and_pd(a, Simd{} + 1/0.);
17+
return _mm256_and_pd(a, dx4{} + 1/0.);
1818
#else
1919
return a < 0 ? -a : a;
2020
#endif
2121
}
2222

23-
template<typename Simd>
24-
i64x4 lround(Simd a) {
23+
i64x4 lround(dx4 a) {
2524
#ifdef __AVX2__
2625
return __builtin_convertvector(_mm256_round_pd(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC), i64x4);
2726
#else
2827
return __builtin_convertvector(a < 0 ? a - 0.5 : a + 0.5, i64x4);
2928
#endif
3029
}
3130

32-
template<typename Simd>
33-
Simd round(Simd a) {
31+
dx4 round(dx4 a) {
3432
#ifdef __AVX2__
3533
return _mm256_round_pd(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
3634
#else
@@ -55,6 +53,14 @@ namespace cp_algo {
5553
return montgomery_reduce(u64x4(_mm256_mul_epu32(__m256i(x), __m256i(y))), mod, imod);
5654
#else
5755
return montgomery_reduce(x * y, mod, imod);
56+
#endif
57+
}
58+
59+
dx4 rotate_right(dx4 x) {
60+
#ifdef __AVX2__
61+
return _mm256_permute4x64_pd(x, _MM_SHUFFLE(2, 1, 0, 3));
62+
#else
63+
return __builtin_shufflevector(x, x, 3, 0, 1, 2);
5864
#endif
5965
}
6066
}

0 commit comments

Comments
 (0)