Skip to content

Commit 5b81815

Browse files
zakcutner marmeladema
authored and committed
Use constant generics for specialized memcmp
1 parent d0e4299 commit 5b81815

File tree

2 files changed

+18
-180
lines changed

2 files changed

+18
-180
lines changed

src/memcmp.rs

Lines changed: 4 additions & 166 deletions
Original file line number | Diff line number | Diff line change
@@ -1,173 +1,11 @@
1-
#![allow(dead_code)]
2-
31
use std::slice;
42

53
#[inline]
6-
pub unsafe fn memcmp0(_: *const u8, _: *const u8, n: usize) -> bool {
7-
debug_assert_eq!(n, 0);
8-
true
9-
}
10-
11-
#[inline]
12-
pub unsafe fn memcmp1(left: *const u8, right: *const u8, n: usize) -> bool {
13-
debug_assert_eq!(n, 1);
14-
*left == *right
15-
}
16-
17-
#[inline]
18-
pub unsafe fn memcmp2(left: *const u8, right: *const u8, n: usize) -> bool {
19-
debug_assert_eq!(n, 2);
20-
let left = left.cast::<u16>();
21-
let right = right.cast::<u16>();
22-
left.read_unaligned() == right.read_unaligned()
23-
}
24-
25-
#[inline]
26-
pub unsafe fn memcmp3(left: *const u8, right: *const u8, n: usize) -> bool {
27-
debug_assert_eq!(n, 3);
28-
memcmp2(left, right, 2) && memcmp1(left.add(2), right.add(2), 1)
29-
}
30-
31-
#[inline]
32-
pub unsafe fn memcmp4(left: *const u8, right: *const u8, n: usize) -> bool {
33-
debug_assert_eq!(n, 4);
34-
let left = left.cast::<u32>();
35-
let right = right.cast::<u32>();
36-
left.read_unaligned() == right.read_unaligned()
37-
}
38-
39-
#[inline]
40-
pub unsafe fn memcmp5(left: *const u8, right: *const u8, n: usize) -> bool {
41-
debug_assert_eq!(n, 5);
42-
memcmp4(left, right, 4) && memcmp1(left.add(4), right.add(4), 1)
43-
}
44-
45-
#[inline]
46-
pub unsafe fn memcmp6(left: *const u8, right: *const u8, n: usize) -> bool {
47-
debug_assert_eq!(n, 6);
48-
memcmp4(left, right, 4) && memcmp2(left.add(4), right.add(4), 2)
49-
}
50-
51-
#[inline]
52-
pub unsafe fn memcmp7(left: *const u8, right: *const u8, n: usize) -> bool {
53-
debug_assert_eq!(n, 7);
54-
memcmp4(left, right, 4) && memcmp3(left.add(4), right.add(4), 3)
55-
}
56-
57-
#[inline]
58-
pub unsafe fn memcmp8(left: *const u8, right: *const u8, n: usize) -> bool {
59-
debug_assert_eq!(n, 8);
60-
let left = left.cast::<u64>();
61-
let right = right.cast::<u64>();
62-
left.read_unaligned() == right.read_unaligned()
63-
}
64-
65-
#[inline]
66-
pub unsafe fn memcmp9(left: *const u8, right: *const u8, n: usize) -> bool {
67-
debug_assert_eq!(n, 9);
68-
memcmp8(left, right, 8) && memcmp1(left.add(8), right.add(8), 1)
69-
}
70-
71-
#[inline]
72-
pub unsafe fn memcmp10(left: *const u8, right: *const u8, n: usize) -> bool {
73-
debug_assert_eq!(n, 10);
74-
memcmp8(left, right, 8) && memcmp2(left.add(8), right.add(8), 2)
75-
}
76-
77-
#[inline]
78-
pub unsafe fn memcmp11(left: *const u8, right: *const u8, n: usize) -> bool {
79-
debug_assert_eq!(n, 11);
80-
memcmp8(left, right, 8) && memcmp3(left.add(8), right.add(8), 3)
81-
}
82-
83-
#[inline]
84-
pub unsafe fn memcmp12(left: *const u8, right: *const u8, n: usize) -> bool {
85-
debug_assert_eq!(n, 12);
86-
memcmp8(left, right, 8) && memcmp4(left.add(8), right.add(8), 4)
87-
}
88-
89-
#[inline]
90-
pub unsafe fn memcmp(left: *const u8, right: *const u8, n: usize) -> bool {
4+
pub unsafe fn generic(left: *const u8, right: *const u8, n: usize) -> bool {
915
slice::from_raw_parts(left, n) == slice::from_raw_parts(right, n)
926
}
937

94-
#[cfg(test)]
95-
mod tests {
96-
fn memcmp(f: unsafe fn(*const u8, *const u8, usize) -> bool, n: usize) {
97-
let left = vec![b'0'; n];
98-
unsafe { assert!(f(left.as_ptr(), left.as_ptr(), n)) };
99-
unsafe { assert!(super::memcmp(left.as_ptr(), left.as_ptr(), n)) };
100-
101-
for i in 0..n {
102-
let mut right = left.clone();
103-
right[i] = b'1';
104-
unsafe { assert!(!f(left.as_ptr(), right.as_ptr(), n)) };
105-
unsafe { assert!(!super::memcmp(left.as_ptr(), right.as_ptr(), n)) };
106-
}
107-
}
108-
109-
#[test]
110-
fn memcmp0() {
111-
memcmp(super::memcmp0, 0);
112-
}
113-
114-
#[test]
115-
fn memcmp1() {
116-
memcmp(super::memcmp1, 1);
117-
}
118-
119-
#[test]
120-
fn memcmp2() {
121-
memcmp(super::memcmp2, 2);
122-
}
123-
124-
#[test]
125-
fn memcmp3() {
126-
memcmp(super::memcmp3, 3);
127-
}
128-
129-
#[test]
130-
fn memcmp4() {
131-
memcmp(super::memcmp4, 4);
132-
}
133-
134-
#[test]
135-
fn memcmp5() {
136-
memcmp(super::memcmp5, 5);
137-
}
138-
139-
#[test]
140-
fn memcmp6() {
141-
memcmp(super::memcmp6, 6);
142-
}
143-
144-
#[test]
145-
fn memcmp7() {
146-
memcmp(super::memcmp7, 7);
147-
}
148-
149-
#[test]
150-
fn memcmp8() {
151-
memcmp(super::memcmp8, 8);
152-
}
153-
154-
#[test]
155-
fn memcmp9() {
156-
memcmp(super::memcmp9, 9);
157-
}
158-
159-
#[test]
160-
fn memcmp10() {
161-
memcmp(super::memcmp10, 10);
162-
}
163-
164-
#[test]
165-
fn memcmp11() {
166-
memcmp(super::memcmp11, 11);
167-
}
168-
169-
#[test]
170-
fn memcmp12() {
171-
memcmp(super::memcmp12, 12);
172-
}
8+
#[inline]
9+
pub unsafe fn specialized<const N: usize>(left: *const u8, right: *const u8) -> bool {
10+
slice::from_raw_parts(left, N) == slice::from_raw_parts(right, N)
17311
}

src/x86.rs

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -282,20 +282,20 @@ impl<N: Needle> Avx2Searcher<N> {
282282
let chunk = chunk.add(eq.trailing_zeros() as usize);
283283
let equal = match N::SIZE {
284284
Some(0) => unreachable!(),
285-
Some(1) => memcmp::memcmp0(chunk, needle, 0),
286-
Some(2) => memcmp::memcmp1(chunk, needle, 1),
287-
Some(3) => memcmp::memcmp2(chunk, needle, 2),
288-
Some(4) => memcmp::memcmp3(chunk, needle, 3),
289-
Some(5) => memcmp::memcmp4(chunk, needle, 4),
290-
Some(6) => memcmp::memcmp5(chunk, needle, 5),
291-
Some(7) => memcmp::memcmp6(chunk, needle, 6),
292-
Some(8) => memcmp::memcmp7(chunk, needle, 7),
293-
Some(9) => memcmp::memcmp8(chunk, needle, 8),
294-
Some(10) => memcmp::memcmp9(chunk, needle, 9),
295-
Some(11) => memcmp::memcmp10(chunk, needle, 10),
296-
Some(12) => memcmp::memcmp11(chunk, needle, 11),
297-
Some(13) => memcmp::memcmp12(chunk, needle, 12),
298-
_ => memcmp::memcmp(chunk, needle, self.needle.size() - 1),
285+
Some(1) => memcmp::specialized::<0>(chunk, needle),
286+
Some(2) => memcmp::specialized::<1>(chunk, needle),
287+
Some(3) => memcmp::specialized::<2>(chunk, needle),
288+
Some(4) => memcmp::specialized::<3>(chunk, needle),
289+
Some(5) => memcmp::specialized::<4>(chunk, needle),
290+
Some(6) => memcmp::specialized::<5>(chunk, needle),
291+
Some(7) => memcmp::specialized::<6>(chunk, needle),
292+
Some(8) => memcmp::specialized::<7>(chunk, needle),
293+
Some(9) => memcmp::specialized::<8>(chunk, needle),
294+
Some(10) => memcmp::specialized::<9>(chunk, needle),
295+
Some(11) => memcmp::specialized::<10>(chunk, needle),
296+
Some(12) => memcmp::specialized::<11>(chunk, needle),
297+
Some(13) => memcmp::specialized::<12>(chunk, needle),
298+
_ => memcmp::generic(chunk, needle, self.needle.size() - 1),
299299
};
300300
if equal {
301301
return true;

0 commit comments

Comments
 (0)