Skip to content

Commit 1a4696d

Browse files
committed
[libc][NFC] Use new approach based on types to code memset
1 parent 991c7e1 commit 1a4696d

File tree

4 files changed

+86
-91
lines changed

4 files changed

+86
-91
lines changed

libc/src/string/memory_utils/memset_implementations.h

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,86 +26,101 @@ namespace __llvm_libc {
2626
inline_memset_embedded_tiny(Ptr dst, uint8_t value, size_t count) {
2727
LIBC_LOOP_NOUNROLL
2828
for (size_t offset = 0; offset < count; ++offset)
29-
generic::Memset<1, 1>::block(dst + offset, value);
29+
generic::Memset<uint8_t>::block(dst + offset, value);
3030
}
3131

3232
#if defined(LIBC_TARGET_ARCH_IS_X86)
33-
template <size_t MaxSize>
3433
[[maybe_unused]] LIBC_INLINE static void
3534
inline_memset_x86(Ptr dst, uint8_t value, size_t count) {
35+
#if defined(__AVX512F__)
36+
using uint128_t = uint8x16_t;
37+
using uint256_t = uint8x32_t;
38+
using uint512_t = uint8x64_t;
39+
#elif defined(__AVX__)
40+
using uint128_t = uint8x16_t;
41+
using uint256_t = uint8x32_t;
42+
using uint512_t = cpp::array<uint8x32_t, 2>;
43+
#elif defined(__SSE2__)
44+
using uint128_t = uint8x16_t;
45+
using uint256_t = cpp::array<uint8x16_t, 2>;
46+
using uint512_t = cpp::array<uint8x16_t, 4>;
47+
#else
48+
using uint128_t = cpp::array<uint64_t, 2>;
49+
using uint256_t = cpp::array<uint64_t, 4>;
50+
using uint512_t = cpp::array<uint64_t, 8>;
51+
#endif
52+
3653
if (count == 0)
3754
return;
3855
if (count == 1)
39-
return generic::Memset<1, MaxSize>::block(dst, value);
56+
return generic::Memset<uint8_t>::block(dst, value);
4057
if (count == 2)
41-
return generic::Memset<2, MaxSize>::block(dst, value);
58+
return generic::Memset<uint16_t>::block(dst, value);
4259
if (count == 3)
43-
return generic::Memset<3, MaxSize>::block(dst, value);
60+
return generic::Memset<uint16_t, uint8_t>::block(dst, value);
4461
if (count <= 8)
45-
return generic::Memset<4, MaxSize>::head_tail(dst, value, count);
62+
return generic::Memset<uint32_t>::head_tail(dst, value, count);
4663
if (count <= 16)
47-
return generic::Memset<8, MaxSize>::head_tail(dst, value, count);
64+
return generic::Memset<uint64_t>::head_tail(dst, value, count);
4865
if (count <= 32)
49-
return generic::Memset<16, MaxSize>::head_tail(dst, value, count);
66+
return generic::Memset<uint128_t>::head_tail(dst, value, count);
5067
if (count <= 64)
51-
return generic::Memset<32, MaxSize>::head_tail(dst, value, count);
68+
return generic::Memset<uint256_t>::head_tail(dst, value, count);
5269
if (count <= 128)
53-
return generic::Memset<64, MaxSize>::head_tail(dst, value, count);
70+
return generic::Memset<uint512_t>::head_tail(dst, value, count);
5471
// Aligned loop
55-
generic::Memset<32, MaxSize>::block(dst, value);
72+
generic::Memset<uint256_t>::block(dst, value);
5673
align_to_next_boundary<32>(dst, count);
57-
return generic::Memset<32, MaxSize>::loop_and_tail(dst, value, count);
74+
return generic::Memset<uint256_t>::loop_and_tail(dst, value, count);
5875
}
5976
#endif // defined(LIBC_TARGET_ARCH_IS_X86)
6077

6178
#if defined(LIBC_TARGET_ARCH_IS_AARCH64)
62-
template <size_t MaxSize>
6379
[[maybe_unused]] LIBC_INLINE static void
6480
inline_memset_aarch64(Ptr dst, uint8_t value, size_t count) {
81+
static_assert(aarch64::kNeon, "aarch64 supports vector types");
82+
using uint128_t = uint8x16_t;
83+
using uint256_t = uint8x32_t;
84+
using uint512_t = uint8x64_t;
6585
if (count == 0)
6686
return;
6787
if (count <= 3) {
68-
generic::Memset<1, MaxSize>::block(dst, value);
88+
generic::Memset<uint8_t>::block(dst, value);
6989
if (count > 1)
70-
generic::Memset<2, MaxSize>::tail(dst, value, count);
90+
generic::Memset<uint16_t>::tail(dst, value, count);
7191
return;
7292
}
7393
if (count <= 8)
74-
return generic::Memset<4, MaxSize>::head_tail(dst, value, count);
94+
return generic::Memset<uint32_t>::head_tail(dst, value, count);
7595
if (count <= 16)
76-
return generic::Memset<8, MaxSize>::head_tail(dst, value, count);
96+
return generic::Memset<uint64_t>::head_tail(dst, value, count);
7797
if (count <= 32)
78-
return generic::Memset<16, MaxSize>::head_tail(dst, value, count);
98+
return generic::Memset<uint128_t>::head_tail(dst, value, count);
7999
if (count <= (32 + 64)) {
80-
generic::Memset<32, MaxSize>::block(dst, value);
100+
generic::Memset<uint256_t>::block(dst, value);
81101
if (count <= 64)
82-
return generic::Memset<32, MaxSize>::tail(dst, value, count);
83-
generic::Memset<32, MaxSize>::block(dst + 32, value);
84-
generic::Memset<32, MaxSize>::tail(dst, value, count);
102+
return generic::Memset<uint256_t>::tail(dst, value, count);
103+
generic::Memset<uint256_t>::block(dst + 32, value);
104+
generic::Memset<uint256_t>::tail(dst, value, count);
85105
return;
86106
}
87107
if (count >= 448 && value == 0 && aarch64::neon::hasZva()) {
88-
generic::Memset<64, MaxSize>::block(dst, 0);
108+
generic::Memset<uint512_t>::block(dst, 0);
89109
align_to_next_boundary<64>(dst, count);
90-
return aarch64::neon::BzeroCacheLine<64>::loop_and_tail(dst, 0, count);
110+
return aarch64::neon::BzeroCacheLine::loop_and_tail(dst, 0, count);
91111
} else {
92-
generic::Memset<16, MaxSize>::block(dst, value);
112+
generic::Memset<uint128_t>::block(dst, value);
93113
align_to_next_boundary<16>(dst, count);
94-
return generic::Memset<64, MaxSize>::loop_and_tail(dst, value, count);
114+
return generic::Memset<uint512_t>::loop_and_tail(dst, value, count);
95115
}
96116
}
97117
#endif // defined(LIBC_TARGET_ARCH_IS_AARCH64)
98118

99119
LIBC_INLINE static void inline_memset(Ptr dst, uint8_t value, size_t count) {
100120
#if defined(LIBC_TARGET_ARCH_IS_X86)
101-
static constexpr size_t kMaxSize = x86::kAvx512F ? 64
102-
: x86::kAvx ? 32
103-
: x86::kSse2 ? 16
104-
: 8;
105-
return inline_memset_x86<kMaxSize>(dst, value, count);
121+
return inline_memset_x86(dst, value, count);
106122
#elif defined(LIBC_TARGET_ARCH_IS_AARCH64)
107-
static constexpr size_t kMaxSize = aarch64::kNeon ? 16 : 8;
108-
return inline_memset_aarch64<kMaxSize>(dst, value, count);
123+
return inline_memset_aarch64(dst, value, count);
109124
#else
110125
return inline_memset_embedded_tiny(dst, value, count);
111126
#endif

libc/src/string/memory_utils/op_aarch64.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@ static inline constexpr bool kNeon = LLVM_LIBC_IS_DEFINED(__ARM_NEON);
3030

3131
namespace neon {
3232

33-
template <size_t Size> struct BzeroCacheLine {
34-
static constexpr size_t SIZE = Size;
33+
struct BzeroCacheLine {
34+
static constexpr size_t SIZE = 64;
3535

3636
LIBC_INLINE static void block(Ptr dst, uint8_t) {
37-
static_assert(Size == 64);
3837
#if __SIZEOF_POINTER__ == 4
3938
asm("dc zva, %w[dst]" : : [dst] "r"(dst) : "memory");
4039
#else
@@ -43,15 +42,13 @@ template <size_t Size> struct BzeroCacheLine {
4342
}
4443

4544
LIBC_INLINE static void loop_and_tail(Ptr dst, uint8_t value, size_t count) {
46-
static_assert(Size > 1, "a loop of size 1 does not need tail");
4745
size_t offset = 0;
4846
do {
4947
block(dst + offset, value);
5048
offset += SIZE;
5149
} while (offset < count - SIZE);
5250
// Unaligned store, we can't use 'dc zva' here.
53-
static constexpr size_t kMaxSize = kNeon ? 16 : 8;
54-
generic::Memset<Size, kMaxSize>::tail(dst, value, count);
51+
generic::Memset<uint8x64_t>::tail(dst, value, count);
5552
}
5653
};
5754

libc/src/string/memory_utils/op_generic.h

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@
3333

3434
#include <stdint.h>
3535

36-
namespace __llvm_libc::generic {
37-
36+
namespace __llvm_libc {
3837
// Compiler types using the vector attributes.
3938
using uint8x1_t = uint8_t __attribute__((__vector_size__(1)));
4039
using uint8x2_t = uint8_t __attribute__((__vector_size__(2)));
@@ -43,13 +42,14 @@ using uint8x8_t = uint8_t __attribute__((__vector_size__(8)));
4342
using uint8x16_t = uint8_t __attribute__((__vector_size__(16)));
4443
using uint8x32_t = uint8_t __attribute__((__vector_size__(32)));
4544
using uint8x64_t = uint8_t __attribute__((__vector_size__(64)));
45+
} // namespace __llvm_libc
4646

47+
namespace __llvm_libc::generic {
4748
// We accept three types of values as elements for generic operations:
4849
// - scalar : unsigned integral types
4950
// - vector : compiler types using the vector attributes
5051
// - array : a cpp::array<T, N> where T is itself either a scalar or a vector.
5152
// The following traits help discriminate between these cases.
52-
5353
template <typename T>
5454
constexpr bool is_scalar_v = cpp::is_integral_v<T> && cpp::is_unsigned_v<T>;
5555

@@ -109,23 +109,11 @@ template <typename T> T splat(uint8_t value) {
109109
T Out;
110110
// This for loop is optimized out for vector types.
111111
for (size_t i = 0; i < sizeof(T); ++i)
112-
Out[i] = static_cast<uint8_t>(value);
112+
Out[i] = value;
113113
return Out;
114114
}
115115
}
116116

117-
template <typename T> void set(Ptr dst, uint8_t value) {
118-
static_assert(is_element_type_v<T>);
119-
if constexpr (is_scalar_v<T> || is_vector_v<T>) {
120-
store<T>(dst, splat<T>(value));
121-
} else if constexpr (is_array_v<T>) {
122-
using value_type = typename T::value_type;
123-
const value_type Splat = splat<value_type>(value);
124-
for (size_t I = 0; I < array_size_v<T>; ++I)
125-
store<value_type>(dst + (I * sizeof(value_type)), Splat);
126-
}
127-
}
128-
129117
static_assert((UINTPTR_MAX == 4294967295U) ||
130118
(UINTPTR_MAX == 18446744073709551615UL),
131119
"We currently only support 32- or 64-bit platforms");
@@ -149,9 +137,7 @@ constexpr bool is_decreasing_size() {
149137
}
150138

151139
template <size_t Size, typename... Ts> struct Largest;
152-
template <size_t Size> struct Largest<Size> {
153-
using type = uint8_t;
154-
};
140+
template <size_t Size> struct Largest<Size> : cpp::type_identity<uint8_t> {};
155141
template <size_t Size, typename T, typename... Ts>
156142
struct Largest<Size, T, Ts...> {
157143
using next = Largest<Size, Ts...>;
@@ -179,6 +165,11 @@ template <typename First, typename... Ts> struct SupportedTypes {
179165
using TypeFor = typename details::Largest<Size, First, Ts...>::type;
180166
};
181167

168+
// Returns the sum of the sizeof of all the TS types.
169+
template <typename... TS> static constexpr size_t sum_sizeof() {
170+
return (... + sizeof(TS));
171+
}
172+
182173
// Map from sizes to structures offering static load, store and splat methods.
183174
// Note: On platforms lacking vector support, we use the ArrayType below and
184175
// decompose the operation in smaller pieces.
@@ -220,27 +211,23 @@ using getTypeFor = cpp::conditional_t<
220211

221212
///////////////////////////////////////////////////////////////////////////////
222213
// Memset
223-
// The MaxSize template argument gives the maximum size handled natively by the
224-
// platform. For instance on x86 with AVX support this would be 32. If a size
225-
// greater than MaxSize is requested we break the operation down in smaller
226-
// pieces of size MaxSize.
227214
///////////////////////////////////////////////////////////////////////////////
228-
template <size_t Size, size_t MaxSize> struct Memset {
229-
static_assert(is_power2(MaxSize));
230-
static constexpr size_t SIZE = Size;
215+
216+
template <typename T, typename... TS> struct Memset {
217+
static constexpr size_t SIZE = sum_sizeof<T, TS...>();
231218

232219
LIBC_INLINE static void block(Ptr dst, uint8_t value) {
233-
if constexpr (Size == 3) {
234-
Memset<1, MaxSize>::block(dst + 2, value);
235-
Memset<2, MaxSize>::block(dst, value);
236-
} else {
237-
using T = details::getTypeFor<Size, MaxSize>;
238-
if constexpr (details::is_void_v<T>) {
239-
deferred_static_assert("Unimplemented Size");
240-
} else {
241-
set<T>(dst, value);
242-
}
220+
static_assert(is_element_type_v<T>);
221+
if constexpr (is_scalar_v<T> || is_vector_v<T>) {
222+
store<T>(dst, splat<T>(value));
223+
} else if constexpr (is_array_v<T>) {
224+
using value_type = typename T::value_type;
225+
const auto Splat = splat<value_type>(value);
226+
for (size_t I = 0; I < array_size_v<T>; ++I)
227+
store<value_type>(dst + (I * sizeof(value_type)), Splat);
243228
}
229+
if constexpr (sizeof...(TS))
230+
Memset<TS...>::block(dst + sizeof(T), value);
244231
}
245232

246233
LIBC_INLINE static void tail(Ptr dst, uint8_t value, size_t count) {
@@ -253,7 +240,7 @@ template <size_t Size, size_t MaxSize> struct Memset {
253240
}
254241

255242
LIBC_INLINE static void loop_and_tail(Ptr dst, uint8_t value, size_t count) {
256-
static_assert(SIZE > 1);
243+
static_assert(SIZE > 1, "a loop of size 1 does not need tail");
257244
size_t offset = 0;
258245
do {
259246
block(dst + offset, value);

libc/test/src/string/memory_utils/op_tests.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,24 +119,20 @@ using MemsetImplementations = testing::TypeList<
119119
builtin::Memset<64>,
120120
#endif
121121
#ifdef LLVM_LIBC_HAS_UINT64
122-
generic::Memset<8, 8>, //
123-
generic::Memset<16, 8>, //
124-
generic::Memset<32, 8>, //
125-
generic::Memset<64, 8>, //
122+
generic::Memset<uint64_t>, generic::Memset<cpp::array<uint64_t, 2>>,
126123
#endif
127124
#ifdef __AVX512F__
128-
generic::Memset<64, 64>, // prevents warning about avx512f
125+
generic::Memset<uint8x64_t>, generic::Memset<cpp::array<uint8x64_t, 2>>,
129126
#endif
130-
generic::Memset<1, 1>, //
131-
generic::Memset<2, 1>, //
132-
generic::Memset<2, 2>, //
133-
generic::Memset<4, 2>, //
134-
generic::Memset<4, 4>, //
135-
generic::Memset<16, 16>, //
136-
generic::Memset<32, 16>, //
137-
generic::Memset<64, 16>, //
138-
generic::Memset<32, 32>, //
139-
generic::Memset<64, 32> //
127+
#ifdef __AVX__
128+
generic::Memset<uint8x32_t>, generic::Memset<cpp::array<uint8x32_t, 2>>,
129+
#endif
130+
#ifdef __SSE2__
131+
generic::Memset<uint8x16_t>, generic::Memset<cpp::array<uint8x16_t, 2>>,
132+
#endif
133+
generic::Memset<uint32_t>, generic::Memset<cpp::array<uint32_t, 2>>, //
134+
generic::Memset<uint16_t>, generic::Memset<cpp::array<uint16_t, 2>>, //
135+
generic::Memset<uint8_t>, generic::Memset<cpp::array<uint8_t, 2>> //
140136
>;
141137

142138
// Adapt CheckMemset signature to op implementation signatures.

0 commit comments

Comments
 (0)