Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "ahocorasick_rs"
version = "0.22.2"
edition = "2021"
edition = "2024"
authors = ["G-Research <c.rindi@gr-oss.io>", "Itamar Turner-Trauring <itamar@pythonspeed.com>"]
description = "Search a string for multiple substrings at once"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[toolchain]
channel = "1.84.1"
channel = "1.85.1"
components = ["rustfmt", "clippy"]
79 changes: 42 additions & 37 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,35 +143,38 @@ impl PyAhoCorasick {
let patterns_error: Cell<Option<PyErr>> = Cell::new(None);

// Convert the `patterns` iterable into an Iterator over Py<PyString>:
let mut patterns_iter = patterns.iter()?.map_while(|pat| {
pat.and_then(|i| i.downcast_into::<PyString>().map_err(PyErr::from).map(|i|i.into_py(py)))
.map_or_else(
|e| {
patterns_error.set(Some(e));
None
},
Some::<Py<PyString>>,
)
let mut patterns_iter = patterns.try_iter()?.map_while(|pat| {
pat.and_then(|i| {
i.downcast_into::<PyString>()
.map_err(PyErr::from)
.map(|i| i.unbind())
})
.map_or_else(
|e| {
patterns_error.set(Some(e));
None
},
Some::<Py<PyString>>,
)
});

// If store_patterns is None (the default), use a heuristic to decide
// whether to store patterns.
let mut patterns: Vec<Py<PyString>> = vec![];
let store_patterns = store_patterns
.unwrap_or_else(|| {
let mut total = 0;
let mut store_patterns = true;
for s in patterns_iter.by_ref() {
// Highly unlikely that strings will fail to return length, so just expect().
total += s.bind(py).len().expect("String doesn't have length?");
patterns.push(s);
if total > 4096 {
store_patterns = false;
break;
}
let store_patterns = store_patterns.unwrap_or_else(|| {
let mut total = 0;
let mut store_patterns = true;
for s in patterns_iter.by_ref() {
// Highly unlikely that strings will fail to return length, so just expect().
total += s.bind(py).len().expect("String doesn't have length?");
patterns.push(s);
if total > 4096 {
store_patterns = false;
break;
}
store_patterns
});
}
store_patterns
});

if store_patterns {
for s in patterns_iter.by_ref() {
Expand All @@ -183,7 +186,8 @@ impl PyAhoCorasick {
.kind(implementation.map(|i| i.into()))
.match_kind(matchkind.into())
.build(
patterns.clone()
patterns
.clone()
.into_iter()
.chain(patterns_iter)
.chunks(10 * 1024)
Expand Down Expand Up @@ -246,23 +250,24 @@ impl PyAhoCorasick {
/// Return matches as list of patterns (i.e. strings). If ``overlapping`` is
/// ``False`` (the default), don't include overlapping results.
#[pyo3(signature = (haystack, overlapping = false))]
fn find_matches_as_strings(
self_: PyRef<Self>,
haystack: &str,
fn find_matches_as_strings<'py>(
self_: PyRef<'py, Self>,
haystack: &'py str,
overlapping: bool,
) -> PyResult<Py<PyList>> {
) -> PyResult<Bound<'py, PyList>> {
let py = self_.py();
let matches = get_matches(&self_.ac_impl, haystack.as_bytes(), overlapping)?;
let matches = py.allow_threads(|| matches.collect::<Vec<_>>().into_iter());
let result = if let Some(ref patterns) = self_.patterns {
PyList::new_bound(py, matches.map(|m| patterns[m.pattern()].clone_ref(py)))
} else {
PyList::new_bound(
let result = match self_.patterns {
Some(ref patterns) => {
PyList::new(py, matches.map(|m| patterns[m.pattern()].clone_ref(py)))
}
_ => PyList::new(
py,
matches.map(|m| PyString::new_bound(py, &haystack[m.start()..m.end()])),
)
matches.map(|m| PyString::new(py, &haystack[m.start()..m.end()])),
),
};
Ok(result.into())
result
}
}

Expand All @@ -277,7 +282,7 @@ impl<'py> TryFrom<Bound<'py, PyAny>> for PyBufferBytes<'py> {

// Get a PyBufferBytes from a Python object
fn try_from(obj: Bound<'py, PyAny>) -> PyResult<Self> {
let buffer = PyBuffer::<u8>::get_bound(&obj).map_err(PyErr::from)?;
let buffer = PyBuffer::<u8>::get(&obj).map_err(PyErr::from)?;

if buffer.dimensions() > 1 {
return Err(PyTypeError::new_err(
Expand Down Expand Up @@ -366,7 +371,7 @@ impl PyBytesAhoCorasick {
// Convert the `patterns` iterable into an Iterator over PyBufferBytes
let patterns_iter =
patterns
.iter()?
.try_iter()?
.map_while(|pat| match pat.and_then(PyBufferBytes::try_from) {
Ok(pat) => {
if pat.as_ref().is_empty() {
Expand Down