Specialize ASCII case for substr() #12444

Merged
merged 2 commits into from
Sep 17, 2024
107 changes: 87 additions & 20 deletions datafusion/functions/src/unicode/substr.rs
@@ -16,18 +16,18 @@
// under the License.

use std::any::Any;
use std::cmp::max;
use std::sync::Arc;

use crate::string::common::StringArrayType;
use crate::utils::{make_scalar_function, utf8_to_str_type};
use arrow::array::{
make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView,
GenericStringArray, OffsetSizeTrait, StringViewArray,
make_view, Array, ArrayIter, ArrayRef, AsArray, ByteView, GenericStringArray,
OffsetSizeTrait, StringViewArray,
};
use arrow::datatypes::DataType;
use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
use datafusion_common::cast::as_int64_array;
use datafusion_common::{exec_datafusion_err, exec_err, Result};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

@@ -119,19 +119,27 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
}

// Convert the given `start` and `count` to valid byte indices within `input` string
//
// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)`
// `start` is 1-based, if `count` is not provided count to the end of the string
// Input indices are character-based, and return values are byte indices
// The input bounds can be outside string bounds, this function will return
// the intersection between input bounds and valid string bounds
// `input_ascii_only` is used to optimize this function if `input` is ASCII-only
//
// * Example
// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx]
// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)`
// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)`
// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)`
fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, usize) {
let start = start - 1;
fn get_true_start_end(
input: &str,
start: i64,
count: Option<u64>,
is_input_ascii_only: bool,
) -> (usize, usize) {
let start = start.checked_sub(1).unwrap_or(start);

let end = match count {
Some(count) => start + count as i64,
None => input.len() as i64,
@@ -142,6 +150,14 @@ fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, us
let end = end.clamp(0, input.len() as i64) as usize;
let count = end - start;

// If input is ASCII-only, byte-based indices are equal to char-based indices
if is_input_ascii_only {
return (start, end);
}

// Otherwise, calculate byte indices from char indices
// Note this decoding is relatively expensive for this simple `substr` function,
// so the implementation attempts to decode in one pass (which causes the complexity)
let (mut st, mut ed) = (input.len(), input.len());
let mut start_counting = false;
let mut cnt = 0;
@@ -197,6 +213,29 @@ fn string_view_substr(

let start_array = as_int64_array(&args[0])?;

// Notes for ASCII-only optimization:
//
// String characters are variable length encoded in UTF-8, `substr()` function's
// arguments are character-based, converting them into byte-based indices
// requires expensive decoding.
// However, checking if a string is ASCII-only is relatively cheap.
// If strings are ASCII only, use byte-based indices instead.
//
// A common pattern to call `substr()` is taking a small prefix of a long
// string, such as `substr(long_str_with_1k_chars, 1, 32)`.
// In such case the overhead of ASCII-validation may not be worth it, so
// skip the validation for long strings for now.
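The reasoning in the note above can be illustrated with a minimal, standard-library-only sketch. It is not the PR's `get_true_start_end`: the helper name `char_range_to_byte_range`, its 0-based `start`, and its `is_ascii` flag are assumptions made for illustration. It shows that for ASCII input the character range is already a byte range, while the general case has to walk UTF-8 boundaries.

```rust
/// Hypothetical helper (not from the PR): convert a 0-based character range
/// `[start, start + count)` into a byte range within `s`, clamped to the
/// string bounds.
fn char_range_to_byte_range(
    s: &str,
    start: usize,
    count: usize,
    is_ascii: bool,
) -> (usize, usize) {
    if is_ascii {
        // ASCII: every character is exactly one byte, so character indices
        // can be reused as byte indices after clamping.
        let st = start.min(s.len());
        let ed = start.saturating_add(count).min(s.len());
        return (st, ed);
    }

    // General UTF-8: iterate over the byte offset of every character,
    // followed by `s.len()` as the final boundary.
    let mut boundaries = s
        .char_indices()
        .map(|(i, _)| i)
        .chain(std::iter::once(s.len()));

    // Byte offset of character number `start`, or the end of the string.
    let st = boundaries.nth(start).unwrap_or(s.len());
    // Continue from there to the boundary `count` characters later.
    let ed = if count == 0 {
        st
    } else {
        boundaries.nth(count - 1).unwrap_or(s.len())
    };
    (st, ed)
}
```

For `"Hi🌏"` (two one-byte characters followed by one four-byte character), the general path maps the character range starting at 2 with length 1 to the byte range `(2, 6)`, whereas an ASCII string such as `"hello"` yields the same result on both paths.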
Member

Why not check only the requested string prefix for being ASCII?
Could a `string_view_array.is_ascii` variant validate string prefixes of a given length while still being vectorized?

Contributor

I am not sure whether this is the same question that @findepi is asking, but I wonder if we could recover the performance loss by also using the number of bytes being requested. For example, if the prefix length is less than, say, 32, don't bother checking for ASCII. 🤔

I think short prefixes are likely common (looking for `http://` as a URL prefix, for example). 🤔

Contributor Author

> Why not check only the requested string prefix for being ASCII? Could a `string_view_array.is_ascii` variant validate string prefixes of a given length while still being vectorized?

I think it's a good idea for the current situation.
However, in the long term we might use an alternative approach: validate when reading arrays from storage into memory, and cache this `is_ascii` property within the Arrow array (as suggested by @alamb in #12444 (review)).
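The prefix-only idea raised in this thread might look roughly like the following per-string sketch. It is an assumption-laden illustration, not code from the PR or from Arrow: `substr_ascii_prefix` is a hypothetical name, `start` is 0-based here, and the check is done per row rather than vectorized over the whole array.

```rust
/// Hypothetical helper (not from the PR): try to answer a substring request
/// by validating only the bytes up to the requested end.
///
/// `start` is 0-based and `count` is a number of characters. Returns `None`
/// when the inspected prefix contains non-ASCII bytes, in which case the
/// caller would fall back to the general UTF-8 decoding path.
fn substr_ascii_prefix(s: &str, start: usize, count: usize) -> Option<&str> {
    // Clamp the requested end to the string length.
    let end = start.saturating_add(count).min(s.len());
    let st = start.min(end);

    // If every byte before `end` is ASCII, the first `end` characters occupy
    // exactly the first `end` bytes, so `st` and `end` are char boundaries.
    if s.as_bytes()[..end].is_ascii() {
        Some(&s[st..end])
    } else {
        None
    }
}
```

With this shape, a request such as the first 7 characters of a long URL only inspects 7 bytes, regardless of what follows in the string.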

// TODO: A better heuristic is to use the ratio to decide whether to validate
// like `(start + count) / estimate_avg_strlen > threshold`, but it requires
// specialized implementation for `ScalarValue` input.
let estimate_avg_strlen =
string_view_array.get_buffer_memory_size() / string_view_array.len();
let enable_ascii_fast_path = if estimate_avg_strlen > 256 {
false // Skip ASCII validation
} else {
string_view_array.is_ascii()
};
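As a rough, standalone restatement of the heuristic in this hunk, the sketch below applies the same decision to a plain slice of strings instead of an Arrow array. The function name, the `&[&str]` input, and the empty-input guard are assumptions of this sketch; only the threshold of 256 is taken from the diff.

```rust
/// Threshold taken from the diff: above this average length, validation is skipped.
const LONG_STRING_THRESHOLD: usize = 256;

/// Hypothetical stand-in for the array-level decision, operating on a slice
/// of strings rather than an Arrow array.
fn enable_ascii_fast_path(strings: &[&str]) -> bool {
    // Guard added by this sketch: avoid dividing by zero on empty input.
    if strings.is_empty() {
        return false;
    }

    // Estimate the average string length in bytes.
    let total_bytes: usize = strings.iter().map(|s| s.len()).sum();
    let estimate_avg_strlen = total_bytes / strings.len();

    if estimate_avg_strlen > LONG_STRING_THRESHOLD {
        // Long strings: scanning everything may cost more than it saves
        // when only a short prefix is requested.
        false
    } else {
        // Short strings: a full ASCII scan is cheap, so do it once up front.
        strings.iter().all(|s| s.is_ascii())
    }
}
```

The fast path is enabled only when the batch is both short on average and entirely ASCII; a single non-ASCII value, or a batch of long strings, sends every row through the general decoding path.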

// In either case, `substr(s, i)` or `substr(s, i, cnt)`,
// if any input argument is `NULL`, the result is `NULL`
match args.len() {
@@ -207,7 +246,8 @@ fn string_view_substr(
.zip(start_array.iter())
{
if let (Some(str), Some(start)) = (str_opt, start_opt) {
let (start, end) = get_true_start_end(str, start, None);
let (start, end) =
get_true_start_end(str, start, None, enable_ascii_fast_path);
let substr = &str[start..end];

make_and_append_view(
@@ -239,8 +279,17 @@ fn string_view_substr(
"negative substring length not allowed: substr(<str>, {start}, {count})"
);
} else {
let (start, end) =
get_true_start_end(str, start, Some(count as u64));
if start == i64::MIN {
return exec_err!(
"negative overflow when calculating skip value"
);
}
let (start, end) = get_true_start_end(
str,
start,
Some(count as u64),
enable_ascii_fast_path,
);
let substr = &str[start..end];

make_and_append_view(
@@ -283,9 +332,18 @@ fn string_view_substr(

fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
V: StringArrayType<'a>,
T: OffsetSizeTrait,
{
// Notes for ASCII-only optimization:
// see comment in `string_view_substr()`
let estimate_avg_strlen = string_array.get_buffer_memory_size() / string_array.len();
let enable_ascii_fast_path = if estimate_avg_strlen > 256 {
false // Skip ASCII validation
} else {
string_array.is_ascii()
};

match args.len() {
1 => {
let iter = ArrayIter::new(string_array);
@@ -295,11 +353,14 @@ where
.zip(start_array.iter())
.map(|(string, start)| match (string, start) {
(Some(string), Some(start)) => {
if start <= 0 {
Some(string.to_string())
} else {
Some(string.chars().skip(start as usize - 1).collect())
}
let (start, end) = get_true_start_end(
string,
start,
None,
enable_ascii_fast_path,
); // start, end is byte-based
let substr = &string[start..end];
Some(substr.to_string())
}
_ => None,
})
@@ -322,11 +383,17 @@ where
"negative substring length not allowed: substr(<str>, {start}, {count})"
)
} else {
let skip = max(0, start.checked_sub(1).ok_or_else(
|| exec_datafusion_err!("negative overflow when calculating skip value")
)?);
let count = max(0, count + (if start < 1 { start - 1 } else { 0 }));
Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::<String>()))
if start == i64::MIN {
return exec_err!("negative overflow when calculating skip value")
}
let (start, end) = get_true_start_end(
string,
start,
Some(count as u64),
enable_ascii_fast_path,
); // start, end is byte-based
let substr = &string[start..end];
Ok(Some(substr.to_string()))
}
}
_ => Ok(None),