Skip to content

Commit f8eeaf5

Browse files
committed
Add generic & safe (but unstable) implementation using std::simd
1 parent 86481c8 commit f8eeaf5

File tree

7 files changed

+225
-13
lines changed

7 files changed

+225
-13
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ cfg-if = "1"
1919

2020
[profile.release]
2121
debug = true
22+
23+
[features]
24+
stdsimd = []

src/aarch64.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ static MD: [u8; 16] = [
2626

2727
impl Vector for uint8x16_t {
2828
const LANES: usize = 16;
29+
type Mask = Self;
2930

3031
#[inline]
3132
unsafe fn splat(a: u8) -> Self {
@@ -58,6 +59,7 @@ impl Vector for uint8x16_t {
5859

5960
impl Vector for uint8x8_t {
6061
const LANES: usize = 8;
62+
type Mask = Self;
6163

6264
#[inline]
6365
unsafe fn splat(a: u8) -> Self {
@@ -96,6 +98,7 @@ struct uint8x4_t(uint8x8_t);
9698

9799
impl Vector for uint8x4_t {
98100
const LANES: usize = 4;
101+
type Mask = Self;
99102

100103
#[inline]
101104
unsafe fn splat(a: u8) -> Self {
@@ -136,6 +139,7 @@ struct uint8x2_t(uint8x8_t);
136139

137140
impl Vector for uint8x2_t {
138141
const LANES: usize = 2;
142+
type Mask = Self;
139143

140144
#[inline]
141145
unsafe fn splat(a: u8) -> Self {

src/lib.rs

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,16 @@
1111
allow(stable_features),
1212
feature(aarch64_target_feature)
1313
)]
14+
#![cfg_attr(feature = "stdsimd", feature(portable_simd))]
1415

1516
/// Substring search implementations using aarch64 architecture features.
1617
#[cfg(target_arch = "aarch64")]
1718
pub mod aarch64;
1819

20+
/// Substring search implementations using generic stdsimd features.
21+
#[cfg(feature = "stdsimd")]
22+
pub mod stdsimd;
23+
1924
/// Substring search implementations using x86 architecture features.
2025
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2126
pub mod x86;
@@ -151,21 +156,24 @@ impl MemchrSearcher {
151156
trait Vector: Copy {
152157
const LANES: usize;
153158

159+
type Mask;
160+
154161
unsafe fn splat(a: u8) -> Self;
155162

156163
unsafe fn load(a: *const u8) -> Self;
157164

158-
unsafe fn lanes_eq(a: Self, b: Self) -> Self;
165+
unsafe fn lanes_eq(a: Self, b: Self) -> Self::Mask;
159166

160-
unsafe fn bitwise_and(a: Self, b: Self) -> Self;
167+
unsafe fn bitwise_and(a: Self::Mask, b: Self::Mask) -> Self::Mask;
161168

162-
unsafe fn to_bitmask(a: Self) -> i32;
169+
unsafe fn to_bitmask(a: Self::Mask) -> i32;
163170
}
164171

165172
/// Hash of the first and "last" bytes in the needle for use with the SIMD
166173
/// algorithm implemented by `Avx2Searcher::vector_search_in`. As explained, any
167174
/// byte can be chosen to represent the "last" byte of the hash to prevent
168175
/// worst-case attacks.
176+
#[derive(Debug)]
169177
struct VectorHash<V: Vector> {
170178
first: V,
171179
last: V,
@@ -174,8 +182,8 @@ struct VectorHash<V: Vector> {
174182
impl<V: Vector> VectorHash<V> {
175183
unsafe fn new(first: u8, last: u8) -> Self {
176184
Self {
177-
first: Vector::splat(first),
178-
last: Vector::splat(last),
185+
first: V::splat(first),
186+
last: V::splat(last),
179187
}
180188
}
181189
}
@@ -206,14 +214,14 @@ trait Searcher<N: NeedleWithSize + ?Sized> {
206214
start: *const u8,
207215
mask: i32,
208216
) -> bool {
209-
let first = Vector::load(start);
210-
let last = Vector::load(start.add(self.position()));
217+
let first = V::load(start);
218+
let last = V::load(start.add(self.position()));
211219

212-
let eq_first = Vector::lanes_eq(hash.first, first);
213-
let eq_last = Vector::lanes_eq(hash.last, last);
220+
let eq_first = V::lanes_eq(hash.first, first);
221+
let eq_last = V::lanes_eq(hash.last, last);
214222

215-
let eq = Vector::bitwise_and(eq_first, eq_last);
216-
let mut eq = (Vector::to_bitmask(eq) & mask) as u32;
223+
let eq = V::bitwise_and(eq_first, eq_last);
224+
let mut eq = (V::to_bitmask(eq) & mask) as u32;
217225

218226
let start = start as usize - haystack.as_ptr() as usize;
219227
let chunk = haystack.as_ptr().add(start + 1);
@@ -379,10 +387,19 @@ mod tests {
379387

380388
let searcher = unsafe { NeonSearcher::with_position(needle, position) };
381389
assert_eq!(unsafe { searcher.search_in(haystack) }, result);
382-
} else {
390+
} else if #[cfg(not(feature = "stdsimd"))] {
383391
compile_error!("Unsupported architecture");
384392
}
385393
}
394+
395+
cfg_if::cfg_if! {
396+
if #[cfg(feature = "stdsimd")] {
397+
use crate::stdsimd::StdSimdSearcher;
398+
399+
let searcher = StdSimdSearcher::with_position(needle, position);
400+
assert_eq!(searcher.search_in(haystack), result);
401+
}
402+
}
386403
}
387404

388405
result

src/stdsimd.rs

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#![allow(clippy::missing_safety_doc)]
2+
3+
use crate::{Needle, NeedleWithSize, Searcher, Vector, VectorHash};
4+
#[cfg(feature = "stdsimd")]
5+
use std::simd::*;
6+
7+
/// Converts a SIMD mask into a `u32` bitmask with one bit per lane,
/// regardless of the mask's native bitmask width.
trait ToFixedBitMask: Sized {
    fn to_fixed_bitmask(self) -> u32;
}
10+
11+
// Blanket implementation for every mask whose native bitmask widens
// losslessly into a `u32` (i.e. masks of up to 32 lanes).
impl<const LANES: usize> ToFixedBitMask for Mask<i8, LANES>
where
    LaneCount<LANES>: SupportedLaneCount,
    Self: ToBitMask,
    <Self as ToBitMask>::BitMask: Into<u32>,
{
    #[inline]
    fn to_fixed_bitmask(self) -> u32 {
        self.to_bitmask().into()
    }
}
22+
23+
impl<const LANES: usize> Vector for Simd<u8, LANES>
24+
where
25+
LaneCount<LANES>: SupportedLaneCount,
26+
Mask<i8, LANES>: ToFixedBitMask,
27+
{
28+
const LANES: usize = LANES;
29+
type Mask = Mask<i8, LANES>;
30+
31+
#[inline]
32+
unsafe fn splat(a: u8) -> Self {
33+
Simd::splat(a as u8)
34+
}
35+
36+
#[inline]
37+
unsafe fn load(a: *const u8) -> Self {
38+
std::ptr::read_unaligned(a as *const Self)
39+
}
40+
41+
#[inline]
42+
unsafe fn lanes_eq(a: Self, b: Self) -> Self::Mask {
43+
a.lanes_eq(b)
44+
}
45+
46+
#[inline]
47+
unsafe fn bitwise_and(a: Self::Mask, b: Self::Mask) -> Self::Mask {
48+
a & b
49+
}
50+
51+
#[inline]
52+
unsafe fn to_bitmask(a: Self::Mask) -> i32 {
53+
std::mem::transmute(a.to_fixed_bitmask())
54+
}
55+
}
56+
57+
// Shorthand aliases for the vector widths the searcher dispatches between.
type Simd2 = Simd<u8, 2>;
type Simd4 = Simd<u8, 4>;
type Simd8 = Simd<u8, 8>;
type Simd16 = Simd<u8, 16>;
type Simd32 = Simd<u8, 32>;
62+
63+
/// Converts a `VectorHash` from `N1` lanes to `N2` lanes by re-splatting the
/// first lane of each vector. Every lane of a hash vector holds the same
/// byte (they are built with `splat`), so lane 0 carries all the information.
fn from_hash<const N1: usize, const N2: usize>(
    hash: &VectorHash<Simd<u8, N1>>,
) -> VectorHash<Simd<u8, N2>>
where
    LaneCount<N1>: SupportedLaneCount,
    Mask<i8, N1>: ToFixedBitMask,
    LaneCount<N2>: SupportedLaneCount,
    Mask<i8, N2>: ToFixedBitMask,
{
    VectorHash {
        first: Simd::splat(hash.first.as_array()[0]),
        last: Simd::splat(hash.last.as_array()[0]),
    }
}
77+
78+
/// Searcher for portable simd.
pub struct StdSimdSearcher<N: Needle> {
    // The needle being searched for.
    needle: N,
    // Index of the "last" byte used by the SIMD hash; see `VectorHash`.
    position: usize,
    // Hash splatted at the widest (32-lane) width; narrower widths are
    // derived on demand via `from_hash`.
    simd32_hash: VectorHash<Simd32>,
}
84+
85+
// Accessors required by the shared `Searcher` trait, which provides the
// actual vectorized search routine.
impl<N: Needle> Searcher<N> for StdSimdSearcher<N> {
    fn needle(&self) -> &N {
        &self.needle
    }

    fn position(&self) -> usize {
        self.position
    }
}
94+
95+
impl<N: Needle> StdSimdSearcher<N> {
96+
/// Creates a new searcher for `needle`. By default, `position` is set to
97+
/// the last character in the needle.
98+
///
99+
/// # Panics
100+
///
101+
/// Panics if `needle` is empty or if the associated `SIZE` constant does
102+
/// not correspond to the actual size of `needle`.
103+
pub fn new(needle: N) -> Self {
104+
// Wrapping prevents panicking on unsigned integer underflow when
105+
// `needle` is empty.
106+
let position = needle.size().wrapping_sub(1);
107+
Self::with_position(needle, position)
108+
}
109+
110+
/// Same as `new` but allows additionally specifying the `position` to use.
111+
///
112+
/// # Panics
113+
///
114+
/// Panics if `needle` is empty, if `position` is not a valid index for
115+
/// `needle` or if the associated `SIZE` constant does not correspond to the
116+
/// actual size of `needle`.
117+
#[inline]
118+
pub fn with_position(needle: N, position: usize) -> Self {
119+
// Implicitly checks that the needle is not empty because position is an
120+
// unsized integer.
121+
assert!(position < needle.size());
122+
123+
let bytes = needle.as_bytes();
124+
if let Some(size) = N::SIZE {
125+
assert_eq!(size, bytes.len());
126+
}
127+
128+
let simd32_hash = unsafe { VectorHash::new(bytes[0], bytes[position]) };
129+
130+
Self {
131+
position,
132+
simd32_hash,
133+
needle,
134+
}
135+
}
136+
137+
/// Inlined version of `search_in` for hot call sites.
138+
#[inline]
139+
pub fn inlined_search_in(&self, haystack: &[u8]) -> bool {
140+
if haystack.len() <= self.needle.size() {
141+
return haystack == self.needle.as_bytes();
142+
}
143+
144+
let end = haystack.len() - self.needle.size() + 1;
145+
146+
if end < Simd2::LANES {
147+
unreachable!();
148+
} else if end < Simd4::LANES {
149+
let hash = from_hash::<32, 2>(&self.simd32_hash);
150+
println!("hash: {:?}", hash);
151+
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
152+
} else if end < Simd8::LANES {
153+
let hash = from_hash::<32, 4>(&self.simd32_hash);
154+
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
155+
} else if end < Simd16::LANES {
156+
let hash = from_hash::<32, 8>(&self.simd32_hash);
157+
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
158+
} else if end < Simd32::LANES {
159+
let hash = from_hash::<32, 16>(&self.simd32_hash);
160+
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
161+
} else {
162+
unsafe { self.vector_search_in_default_version(haystack, end, &self.simd32_hash) }
163+
}
164+
}
165+
166+
/// Performs a substring search for the `needle` within `haystack`.
167+
pub fn search_in(&self, haystack: &[u8]) -> bool {
168+
self.inlined_search_in(haystack)
169+
}
170+
}

src/wasm32.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::arch::wasm32::*;
66

77
impl Vector for v128 {
88
const LANES: usize = 16;
9+
type Mask = Self;
910

1011
#[inline]
1112
#[target_feature(enable = "simd128")]
@@ -45,6 +46,7 @@ struct v64(v128);
4546

4647
impl Vector for v64 {
4748
const LANES: usize = 8;
49+
type Mask = Self;
4850

4951
#[inline]
5052
#[target_feature(enable = "simd128")]
@@ -90,6 +92,7 @@ struct v32(v128);
9092

9193
impl Vector for v32 {
9294
const LANES: usize = 4;
95+
type Mask = Self;
9396

9497
#[inline]
9598
#[target_feature(enable = "simd128")]
@@ -135,6 +138,7 @@ struct v16(v128);
135138

136139
impl Vector for v16 {
137140
const LANES: usize = 2;
141+
type Mask = Self;
138142

139143
#[inline]
140144
#[target_feature(enable = "simd128")]

src/x86.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ struct __m16i(__m128i);
3030

3131
impl Vector for __m16i {
3232
const LANES: usize = 2;
33+
type Mask = Self;
3334

3435
#[inline]
3536
#[target_feature(enable = "avx2")]
@@ -76,6 +77,7 @@ struct __m32i(__m128i);
7677

7778
impl Vector for __m32i {
7879
const LANES: usize = 4;
80+
type Mask = Self;
7981

8082
#[inline]
8183
#[target_feature(enable = "avx2")]
@@ -122,6 +124,7 @@ struct __m64i(__m128i);
122124

123125
impl Vector for __m64i {
124126
const LANES: usize = 8;
127+
type Mask = Self;
125128

126129
#[inline]
127130
#[target_feature(enable = "avx2")]
@@ -163,6 +166,7 @@ impl From<__m128i> for __m64i {
163166

164167
impl Vector for __m128i {
165168
const LANES: usize = 16;
169+
type Mask = Self;
166170

167171
#[inline]
168172
#[target_feature(enable = "avx2")]
@@ -197,6 +201,7 @@ impl Vector for __m128i {
197201

198202
impl Vector for __m256i {
199203
const LANES: usize = 32;
204+
type Mask = Self;
200205

201206
#[inline]
202207
#[target_feature(enable = "avx2")]

0 commit comments

Comments
 (0)