From 45b45c3d68d2b15078a654c83b874c1d5a2b7f9b Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Sat, 3 May 2025 01:25:35 +0530 Subject: [PATCH 01/11] Implementation for regex_instr --- datafusion/functions/benches/regx.rs | 41 + datafusion/functions/src/regex/mod.rs | 31 + datafusion/functions/src/regex/regexpinstr.rs | 995 ++++++++++++++++++ .../test_files/regexp/regexp_instr.slt | 182 ++++ .../source/user-guide/sql/scalar_functions.md | 35 + 5 files changed, 1284 insertions(+) create mode 100644 datafusion/functions/src/regex/regexpinstr.rs create mode 100644 datafusion/sqllogictest/test_files/regexp/regexp_instr.slt diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 3a1a6a71173e..e7cba5c5e6b6 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -23,6 +23,7 @@ use arrow::compute::cast; use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_functions::regex::regexpcount::regexp_count_func; +use datafusion_functions::regex::regexpcount::regexp_instr; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; @@ -127,6 +128,46 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function("regexp_instr_1000 string", |b| { + let mut rng = rand::thread_rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + regexp_instr_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_instr should work on utf8"), + ) + }) + }); + + c.bench_function("regexp_instr_1000 utf8view", |b| { + let mut rng = rand::thread_rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + + b.iter(|| { + black_box( + regexp_instr_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_instr should work on utf8view"), + ) + }) + }); + c.bench_function("regexp_like_1000", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 13fbc049af58..b0278ac9df61 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -20,12 +20,14 @@ use std::sync::Arc; pub mod regexpcount; +pub mod regexpinstr; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; // create UDFs make_udf_function!(regexpcount::RegexpCountFunc, regexp_count); +make_udf_function!(regexpinstr::RegexpInstrFunc, regexp_instr); make_udf_function!(regexpmatch::RegexpMatchFunc, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, regexp_like); make_udf_function!(regexpreplace::RegexpReplaceFunc, regexp_replace); @@ -60,6 +62,34 @@ pub mod expr_fn { super::regexp_match().call(args) } + /// Returns index of regular expression matches in a string. + pub fn regexp_instr( + values: Expr, + regex: Expr, + start: Option, + n: Option, + endoption: Option, + flags: Option, + subexpr: Option, + ) -> Expr { + let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + if let Some(n) = n { + args.push(n); + }; + if let Some(endoption) = endoption { + args.push(endoption); + }; + if let Some(flags) = flags { + args.push(flags); + }; + if let Some(subexpr) = subexpr { + args.push(subexpr); + }; + super::regexp_instr().call(args) + } /// Returns true if a has at least one match in a string, false otherwise. pub fn regexp_like(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; @@ -89,6 +119,7 @@ pub fn functions() -> Vec> { vec![ regexp_count(), regexp_match(), + regexp_instr(), regexp_like(), regexp_replace(), ] diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs new file mode 100644 index 000000000000..8fb6c585b6f7 --- /dev/null +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -0,0 +1,995 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType}; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{ + DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, +}; +use arrow::error::ArrowError; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, + TypeSignature::Uniform, Volatility, +}; +use datafusion_macros::user_doc; +use itertools::izip; +use regex::Regex; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::Arc; + +#[user_doc( + doc_section(label = "Regular Expression Functions"), + description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.", + syntax_example = "regexp_instr(str, regexp[, start, N, endoption, flags])", + sql_example = r#"```sql +> SELECT regexp_instr('ABCDEF', 'c(.)(..)', 1, 1, 0, 'i', 2); ++---------------------------------------------------------------+ +| regexp_instr() | ++---------------------------------------------------------------+ +| 5 | ++---------------------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + standard_argument(name = "regexp", prefix = "Regular"), + argument( + name = "start", + description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1" + ), + argument( + name = "N", + description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function." + ), + argument( + name = "endoption", + description = "- **endoption**: Optional. If 0, returns the starting position of the match (default). If 1, returns the ending position of the match. Can be a constant, column, or function." + ), + argument( + name = "flags", + description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"# + ), + argument( + name = "subexpr", + description = "Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match." + ) +)] +#[derive(Debug)] +pub struct RegexpInstrFunc { + signature: Signature, +} + +impl Default for RegexpInstrFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpInstrFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Uniform(2, vec![Utf8View, LargeUtf8, Utf8]), + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Utf8View]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), + Exact(vec![Utf8, Utf8, Int64, Utf8]), + Exact(vec![Utf8, Utf8, Int64, Int64, Int64, Utf8]), + Exact(vec![Utf8, Utf8, Int64, Int64, Int64, Utf8, Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpInstrFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "regexp_instr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + let args = &args.args; + + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .map(|arg| arg.to_array(inferred_length)) + .collect::>>()?; + + let result = regexp_instr_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +pub fn regexp_instr_func(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + if !(2..=7).contains(&args_len) { + return exec_err!("regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 7."); + } + + let values = &args[0]; + match values.data_type() { + Utf8 | LargeUtf8 | Utf8View => (), + other => { + return internal_err!( + "Unsupported data type {other:?} for function regexp_instr" + ); + } + } + + regexp_instr( + values, + &args[1], + if args_len > 2 { Some(&args[2]) } else { None }, + if args_len > 3 { Some(&args[3]) } else { None }, + if args_len > 4 { Some(&args[4]) } else { None }, + if args_len > 5 { Some(&args[5]) } else { None }, + if args_len > 6 { Some(&args[6]) } else { None }, + ) + .map_err(|e| e.into()) +} + +/// `arrow-rs` style implementation of `regexp_instr` function. +/// This function `regexp_instr` is responsible for returning the index of a regular expression pattern +/// within a string array. It supports optional start positions and flags for case insensitivity. +/// +/// The function accepts a variable number of arguments: +/// - `values`: The array of strings to search within. +/// - `regex_array`: The array of regular expression patterns to search for. +/// - `start_array` (optional): The array of start positions for the search. +/// - `nth_array` (optional): The array of start nth for the search. +/// - `endoption_array` (optional): The array of endoption positions for the search. +/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity). +/// - `subexpr_array` (optional): The array of subexpr positions for the search. +/// +/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions, +/// and flags. It uses a cache to store compiled regular expressions for efficiency. +/// +/// # Errors +/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile. +pub fn regexp_instr( + values: &dyn Array, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + nth_array: Option<&dyn Datum>, + endoption_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, + subexpr_array: Option<&dyn Datum>, +) -> Result { + let (regex_array, is_regex_scalar) = regex_array.get(); + let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| { + let (start, is_start_scalar) = start.get(); + (Some(start), is_start_scalar) + }); + let (nth_array, is_nth_scalar) = nth_array.map_or((None, true), |nth| { + let (nth, is_nth_scalar) = nth.get(); + (Some(nth), is_nth_scalar) + }); + let (endoption_array, is_endoption_scalar) = + endoption_array.map_or((None, true), |endoption| { + let (endoption, is_endoption_scalar) = endoption.get(); + (Some(endoption), is_endoption_scalar) + }); + let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), is_flags_scalar) + }); + let (subexpr_array, is_subexpr_scalar) = + subexpr_array.map_or((None, true), |subexpr| { + let (subexpr, is_subexpr_scalar) = subexpr.get(); + (Some(subexpr), is_subexpr_scalar) + }); + + match (values.data_type(), regex_array.data_type(), flags_array) { + (Utf8, Utf8, None) => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + endoption_array.map(|endoption| endoption.as_primitive::()), + is_endoption_scalar, + None, + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + endoption_array.map(|endoption| endoption.as_primitive::()), + is_endoption_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + (LargeUtf8, LargeUtf8, None) => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + endoption_array.map(|endoption| endoption.as_primitive::()), + is_endoption_scalar, + None, + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + endoption_array.map(|endoption| endoption.as_primitive::()), + is_endoption_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + (Utf8View, Utf8View, None) => regexp_instr_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + endoption_array.map(|endoption| endoption.as_primitive::()), + is_endoption_scalar, + None, + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_instr_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + endoption_array.map(|endoption| endoption.as_primitive::()), + is_endoption_scalar, + Some(flags_array.as_string_view()), + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + _ => Err(ArrowError::ComputeError( + "regexp_instr() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(), + )), + } +} + +enum ScalarOrArray { + Scalar(T), + Array(Vec), +} + +impl ScalarOrArray { + fn iter(&self, len: usize) -> Box + '_> { + match self { + ScalarOrArray::Scalar(val) => { + Box::new(std::iter::repeat(val.clone()).take(len)) + } + ScalarOrArray::Array(arr) => Box::new(arr.iter().cloned()), + } + } +} + +pub fn regexp_instr_inner<'a, S>( + values: S, + regex_array: S, + is_regex_scalar: bool, + start_array: Option<&Int64Array>, + is_start_scalar: bool, + nth_array: Option<&Int64Array>, + is_nth_scalar: bool, + endoption_array: Option<&Int64Array>, + is_endoption_scalar: bool, + flags_array: Option, + is_flags_scalar: bool, + subexp_array: Option<&Int64Array>, + is_subexp_scalar: bool, +) -> Result +where + S: StringArrayType<'a>, +{ + let len = values.len(); + let regex_input = if is_regex_scalar || regex_array.len() == 1 { + ScalarOrArray::Scalar(regex_array.value(0)) + } else { + let regex_vec: Vec<&str> = regex_array.iter().map(|v| v.unwrap_or("")).collect(); + ScalarOrArray::Array(regex_vec) + }; + + let start_input = if let Some(start) = start_array { + if is_start_scalar || start.len() == 1 { + ScalarOrArray::Scalar(start.value(0)) + } else { + let start_vec: Vec = (0..start.len()) + .map(|i| if start.is_null(i) { 0 } else { start.value(i) }) // handle nulls as 0 + .collect(); + + ScalarOrArray::Array(start_vec) + } + } else { + if len == 1 { + ScalarOrArray::Scalar(1) + } else { + ScalarOrArray::Array(vec![1; len]) + } + // Default start = 1 + }; + + let nth_input = if let Some(nth) = nth_array { + if is_nth_scalar || nth.len() == 1 { + ScalarOrArray::Scalar(nth.value(0)) + } else { + let nth_vec: Vec = (0..nth.len()) + .map(|i| if nth.is_null(i) { 0 } else { nth.value(i) }) // handle nulls as 0 + .collect(); + ScalarOrArray::Array(nth_vec) + } + } else { + if len == 1 { + ScalarOrArray::Scalar(1) + } + // Default nth = 0 + else { + ScalarOrArray::Array(vec![1; len]) + } + }; + + let endoption_input = if let Some(endoption) = endoption_array { + if is_endoption_scalar || endoption.len() == 1 { + ScalarOrArray::Scalar(endoption.value(0)) + } else { + let endoption_vec: Vec = (0..endoption.len()) + .map(|i| { + if endoption.is_null(i) { + 0 + } else { + endoption.value(i) + } + }) // handle nulls as 0 + .collect(); + ScalarOrArray::Array(endoption_vec) + } + } else { + if len == 1 { + ScalarOrArray::Scalar(0) + } + // Default nth = 0 + else { + ScalarOrArray::Array(vec![0; len]) + } // Default endoption = 0 + }; + + let flags_input = if let Some(ref flags) = flags_array { + if is_flags_scalar || flags.len() == 1 { + ScalarOrArray::Scalar(flags.value(0)) + } else { + let flags_vec: Vec<&str> = flags.iter().map(|v| v.unwrap_or("")).collect(); + ScalarOrArray::Array(flags_vec) + } + } else { + if len == 1 { + ScalarOrArray::Scalar("") + } + // Default flags = "" + else { + ScalarOrArray::Array(vec![""; len]) + } // Default flags = "" + }; + + let subexp_input = if let Some(subexp) = subexp_array { + if is_subexp_scalar || subexp.len() == 1 { + ScalarOrArray::Scalar(subexp.value(0)) + } else { + let subexp_vec: Vec = (0..subexp.len()) + .map(|i| { + if subexp.is_null(i) { + 0 + } else { + subexp.value(i) + } + }) // handle nulls as 0 + .collect(); + ScalarOrArray::Array(subexp_vec) + } + } else { + if len == 1 { + ScalarOrArray::Scalar(0) + } + // Default subexp = 0 + else { + ScalarOrArray::Array(vec![0; len]) + } + }; + + let mut regex_cache = HashMap::new(); + + let result: Result, ArrowError> = izip!( + values.iter(), + regex_input.iter(len), + start_input.iter(len), + nth_input.iter(len), + endoption_input.iter(len), + flags_input.iter(len), + subexp_input.iter(len) + ) + .map(|(value, regex, start, nth, endoption, flags, subexp)| { + if regex.is_empty() { + return Ok(0); + } + + let pattern = compile_and_cache_regex(®ex, Some(flags), &mut regex_cache)?; + + get_index(value, &pattern, start, nth, endoption, subexp) + }) + .collect(); + + Ok(Arc::new(Int64Array::from(result?))) +} + +fn compile_and_cache_regex<'strings, 'cache>( + regex: &'strings str, + flags: Option<&'strings str>, + regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, +) -> Result<&'cache Regex, ArrowError> +where + 'strings: 'cache, +{ + let result = match regex_cache.entry((regex, flags)) { + Entry::Occupied(occupied_entry) => occupied_entry.into_mut(), + Entry::Vacant(vacant_entry) => { + let compiled = compile_regex(regex, flags)?; + vacant_entry.insert(compiled) + } + }; + Ok(result) +} + +fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_instr() does not support global flag".to_string(), + )); + } + format!("(?{}){}", flags, regex) + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {}", + pattern + )) + }) +} + +fn get_index( + value: Option<&str>, + pattern: &Regex, + start: i64, + n: i64, + endoption: i64, + subexpr: i64, +) -> Result { + let value = match value { + None | Some("") => return Ok(0), + Some(value) => value, + }; + + // let start = start.unwrap_or(1); + if start < 1 { + return Err(ArrowError::ComputeError( + "regexp_instr() requires start to be 1-based".to_string(), + )); + } + + // let n = n.unwrap_or(1); // Default to finding the first match + if n < 1 { + return Err(ArrowError::ComputeError( + "N must be 1 or greater".to_string(), + )); + } + + let find_slice = value.chars().skip(start as usize - 1).collect::(); + let matches: Vec<_> = pattern.find_iter(&find_slice).collect(); + + let mut result_index = 0; + if matches.len() < n as usize { + return Ok(result_index); // Return 0 if the N-th match was not found + } else { + let nth_match = matches.get((n - 1) as usize).ok_or_else(|| { + ArrowError::ComputeError("N-th match not found".to_string()) + })?; + + let match_start = nth_match.start() as i64 + start; + let match_end = nth_match.end() as i64 + start; + + result_index = match endoption { + 1 => match_end, // Return end position of match + _ => match_start, // Default: Return start position + }; + } + // Find the N-th match (1-based index) + + // Handle subexpression capturing (if requested) + + if subexpr > 0 { + if let Some(captures) = pattern.captures(&find_slice) { + if let Some(matched) = captures.get(subexpr as usize) { + return Ok(matched.start() as i64 + start); + } + } + return Ok(0); // Return 0 if the subexpression was not found + } + + Ok(result_index) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + use arrow::array::{GenericStringArray, StringViewArray}; + use arrow::datatypes::Field; + use datafusion_expr::ScalarFunctionArgs; + #[test] + fn test_regexp_instr() { + test_case_sensitive_regexp_instr_scalar(); + test_case_sensitive_regexp_instr_scalar_start(); + test_case_sensitive_regexp_instr_scalar_nth(); + test_case_sensitive_regexp_instr_scalar_endoption(); + + test_case_sensitive_regexp_instr_array::>(); + test_case_sensitive_regexp_instr_array::>(); + test_case_sensitive_regexp_instr_array::(); + + test_case_sensitive_regexp_instr_array_start::>(); + test_case_sensitive_regexp_instr_array_start::>(); + test_case_sensitive_regexp_instr_array_start::(); + + test_case_sensitive_regexp_instr_array_nth::>(); + test_case_sensitive_regexp_instr_array_nth::>(); + test_case_sensitive_regexp_instr_array_nth::(); + + test_case_sensitive_regexp_instr_array_endoption::>(); + test_case_sensitive_regexp_instr_array_endoption::>(); + test_case_sensitive_regexp_instr_array_endoption::(); + } + + fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result { + let args_values = args + .iter() + .map(|sv| ColumnarValue::Scalar(sv.clone())) + .collect(); + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + + RegexpInstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_values, + arg_fields, + number_rows: args.len(), + return_field: &Field::new("f", Int64, true), + }) + } + + fn test_case_sensitive_regexp_instr_scalar() { + let values = [ + "hello world", + "abcdefg", + "xyz123xyz", + "no match here", + "", + "abc", + "", + ]; + let regex = ["o", "d", "123", "z", "gg", "", ""]; + + let expected: Vec = vec![5, 4, 4, 0, 0, 0, 0]; + + izip!(values.iter(), regex.iter()) + .enumerate() + .for_each(|(pos, (&v, &r))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let expected = expected.get(pos).cloned(); + let re = regexp_instr_with_scalar_values(&[v_sv, regex_sv]); + // let res_exp = re.unwrap(); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); + let re = regexp_instr_with_scalar_values(&[v_sv, regex_sv]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let re = regexp_instr_with_scalar_values(&[v_sv, regex_sv]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_scalar_start() { + let values = ["abcabcabc", "abcabcabc", ""]; + let regex = ["abc", "abc", "gg"]; + let start = [4, 5, 5]; + let expected: Vec = vec![4, 7, 0]; + + izip!(values.iter(), regex.iter(), start.iter()) + .enumerate() + .for_each(|(pos, (&v, &r, &s))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let expected = expected.get(pos).cloned(); + let re = + regexp_instr_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let re = + regexp_instr_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let re = + regexp_instr_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_scalar_nth() { + let values = ["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]; + let regex = ["abc", "abc", "abc", "abc"]; + let start = [1, 1, 1, 1]; + let nth = [1, 2, 3, 4]; + let expected: Vec = vec![1, 4, 7, 0]; + + izip!(values.iter(), regex.iter(), start.iter(), nth.iter()) + .enumerate() + .for_each(|(pos, (&v, &r, &s, &n))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let expected = expected.get(pos).cloned(); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_scalar_endoption() { + let values = ["abcdefg", "abcdefg"]; + let regex = ["cd", "cd"]; + let start = [1, 1]; + let nth = [1, 1]; + let endoption = [0, 1]; + let expected: Vec = vec![3, 5]; + + izip!( + values.iter(), + regex.iter(), + start.iter(), + nth.iter(), + endoption.iter() + ) + .enumerate() + .for_each(|(pos, (&v, &r, &s, &n, &e))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let endoption_sv = ScalarValue::Int64(Some(e)); + let expected = expected.get(pos).cloned(); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + endoption_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let endoption_sv = ScalarValue::Int64(Some(e)); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + endoption_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let endoption_sv = ScalarValue::Int64(Some(e)); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + endoption_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_array() + where + A: From> + Array + 'static, + { + let values = A::from(vec![ + "hello world", + "abcdefg", + "xyz123xyz", + "no match here", + "", + "abc", + "", + ]); + let regex = A::from(vec!["o", "d", "123", "z", "gg", "", ""]); + + let expected = Int64Array::from(vec![5, 4, 4, 0, 0, 0, 0]); + let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_instr_array_start() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["abcabcabc", "abcabcabc", ""]); + let regex = A::from(vec!["abc", "abc", "gg"]); + let start = Int64Array::from(vec![4, 5, 5]); + let expected = Int64Array::from(vec![4, 7, 0]); + + let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_instr_array_nth() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]); + let regex = A::from(vec!["abc", "abc", "abc", "abc"]); + let start = Int64Array::from(vec![1, 1, 1, 1]); + let nth = Int64Array::from(vec![1, 2, 3, 4]); + let expected = Int64Array::from(vec![1, 4, 7, 0]); + + let re = regexp_instr_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(nth), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_instr_array_endoption() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["abcdefg", "abcdefg"]); + let regex = A::from(vec!["cd", "cd"]); + let start = Int64Array::from(vec![1, 1]); + let nth = Int64Array::from(vec![1, 1]); + let endoption = Int64Array::from(vec![0, 1]); + + let expected = Int64Array::from(vec![3, 5]); + + let re = regexp_instr_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(nth), + Arc::new(endoption), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + +} diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt b/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt new file mode 100644 index 000000000000..021db3dd74a2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +query I +SELECT regexp_instr('123123123123123', '(12)3'); +---- +1 + +query I +SELECT regexp_instr('123123123123', '123', 1); +---- +1 + +query I +SELECT regexp_instr('123123123123', '123', 3); +---- +4 + +query I +SELECT regexp_instr('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_instr('ABCABCABCABC', 'Abc', 1, 2, 0, ''); +---- +0 + +query I +SELECT regexp_instr('ABCABCABCABC', 'Abc', 1, 2, 0, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_instr() requires start to be 1 based +SELECT regexp_instr('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_instr() requires start to be 1 based +SELECT regexp_instr('123123123123', '123', -3); + +query I +SELECT regexp_instr(str, pattern) FROM regexp_test_data; +---- +0 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_instr(str, pattern, start) FROM regexp_test_data; +---- +0 +1 +1 +0 +0 +0 +0 +0 +3 +4 +1 +2 + + +statement ok +CREATE TABLE t_stringview AS +SELECT + arrow_cast(str, 'Utf8View') AS str, + arrow_cast(pattern, 'Utf8View') AS pattern, + arrow_cast(start, 'Int64') AS start +FROM regexp_test_data; + +query I +SELECT regexp_instr(str, pattern, start) FROM t_stringview; +---- +0 +1 +1 +0 +0 +0 +0 +0 +3 +4 +1 +2 + +query I +SELECT regexp_instr( + arrow_cast(str, 'Utf8'), + arrow_cast(pattern, 'LargeUtf8'), + arrow_cast(start, 'Int32') +) FROM t_stringview; +---- +0 +1 +1 +0 +0 +0 +0 +0 +3 +4 +1 +2 + +query I +SELECT regexp_instr(NULL, NULL); +---- +0 + +query I +SELECT regexp_instr(NULL, 'a'); +---- +0 + +query I +SELECT regexp_instr('a', NULL); +---- +0 + +query I +SELECT regexp_instr(NULL, NULL, NULL); +---- +0 + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int); + +query I +SELECT regexp_instr(str, pattern, start) FROM empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES + ('a', NULL, 1), + (NULL, 'a', 1), + (NULL, NULL, 1), + (NULL, NULL, NULL); + +query I +SELECT regexp_instr(str, pattern, start) FROM empty_table; +---- +0 +0 +0 +0 + +statement ok +DROP TABLE t_stringview; + +statement ok +DROP TABLE empty_table; \ No newline at end of file diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index cbcec710e267..406a93d73b3b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1866,6 +1866,41 @@ SELECT regexp_like('aBc', '(b|d)', 'i'); Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +### `regexp_instr` + +Returns the position in a string where the specified occurrence of a [regular expression](https://docs.rs/regex/latest/regex/#syntax) is located. + +```sql +regexp_instr(string, pattern [, start [, N [, endoption [, flags [, subexpr ]]]]]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. +- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function. +- **endoption**: Optional. If 0, returns the starting position of the match (default). If 1, returns the ending position of the match. Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? +- **subexpr**: Optional. An integer indicating which parenthesized subexpression's position to return. Defaults to 0, which means the whole match + +#### Example + +```sql +> select regexp_instr('ABCDEF', 'c(.)(..)', 1, 1, 0, 'i', 2); ++---------------------------------------------------------------+ +| regexp_instr(Utf8('ABCDEF'), Utf8('c(.)(..)'), Int64(1), | +| Int64(1), Int64(0),, 'i', Int64(2),) | ++---------------------------------------------------------------+ +| 5 | ++---------------------------------------------------------------+ +``` + ### `regexp_match` Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. From 4fb0cca3d7e6c91a1adf3e4cd947de8820d63613 Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Sat, 3 May 2025 02:13:06 +0530 Subject: [PATCH 02/11] linting and typo addressed in bench --- datafusion/functions/benches/regx.rs | 2 +- datafusion/functions/src/regex/regexpinstr.rs | 75 ++++++++----------- 2 files changed, 32 insertions(+), 45 deletions(-) diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index e7cba5c5e6b6..b003fd22acbb 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -23,7 +23,7 @@ use arrow::compute::cast; use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_functions::regex::regexpcount::regexp_count_func; -use datafusion_functions::regex::regexpcount::regexp_instr; +use datafusion_functions::regex::regexpinstr::regexp_instr_func; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 8fb6c585b6f7..31552a45feb4 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -346,14 +346,13 @@ enum ScalarOrArray { impl ScalarOrArray { fn iter(&self, len: usize) -> Box + '_> { match self { - ScalarOrArray::Scalar(val) => { - Box::new(std::iter::repeat(val.clone()).take(len)) - } + ScalarOrArray::Scalar(val) => Box::new(std::iter::repeat_n(val.clone(), len)), ScalarOrArray::Array(arr) => Box::new(arr.iter().cloned()), } } } +#[allow(clippy::too_many_arguments)] pub fn regexp_instr_inner<'a, S>( values: S, regex_array: S, @@ -390,13 +389,10 @@ where ScalarOrArray::Array(start_vec) } + } else if len == 1 { + ScalarOrArray::Scalar(1) } else { - if len == 1 { - ScalarOrArray::Scalar(1) - } else { - ScalarOrArray::Array(vec![1; len]) - } - // Default start = 1 + ScalarOrArray::Array(vec![1; len]) }; let nth_input = if let Some(nth) = nth_array { @@ -408,14 +404,12 @@ where .collect(); ScalarOrArray::Array(nth_vec) } - } else { - if len == 1 { - ScalarOrArray::Scalar(1) - } - // Default nth = 0 - else { - ScalarOrArray::Array(vec![1; len]) - } + } else if len == 1 { + ScalarOrArray::Scalar(1) + } + // Default nth = 0 + else { + ScalarOrArray::Array(vec![1; len]) }; let endoption_input = if let Some(endoption) = endoption_array { @@ -433,14 +427,12 @@ where .collect(); ScalarOrArray::Array(endoption_vec) } - } else { - if len == 1 { - ScalarOrArray::Scalar(0) - } - // Default nth = 0 - else { - ScalarOrArray::Array(vec![0; len]) - } // Default endoption = 0 + } else if len == 1 { + ScalarOrArray::Scalar(0) + } + // Default nth = 0 + else { + ScalarOrArray::Array(vec![0; len]) }; let flags_input = if let Some(ref flags) = flags_array { @@ -450,14 +442,12 @@ where let flags_vec: Vec<&str> = flags.iter().map(|v| v.unwrap_or("")).collect(); ScalarOrArray::Array(flags_vec) } - } else { - if len == 1 { - ScalarOrArray::Scalar("") - } - // Default flags = "" - else { - ScalarOrArray::Array(vec![""; len]) - } // Default flags = "" + } else if len == 1 { + ScalarOrArray::Scalar("") + } + // Default flags = "" + else { + ScalarOrArray::Array(vec![""; len]) }; let subexp_input = if let Some(subexp) = subexp_array { @@ -475,14 +465,12 @@ where .collect(); ScalarOrArray::Array(subexp_vec) } - } else { - if len == 1 { - ScalarOrArray::Scalar(0) - } - // Default subexp = 0 - else { - ScalarOrArray::Array(vec![0; len]) - } + } else if len == 1 { + ScalarOrArray::Scalar(0) + } + // Default subexp = 0 + else { + ScalarOrArray::Array(vec![0; len]) }; let mut regex_cache = HashMap::new(); @@ -501,9 +489,9 @@ where return Ok(0); } - let pattern = compile_and_cache_regex(®ex, Some(flags), &mut regex_cache)?; + let pattern = compile_and_cache_regex(regex, Some(flags), &mut regex_cache)?; - get_index(value, &pattern, start, nth, endoption, subexp) + get_index(value, pattern, start, nth, endoption, subexp) }) .collect(); @@ -991,5 +979,4 @@ mod tests { .unwrap(); assert_eq!(re.as_ref(), &expected); } - } From cff8d42d6ec3c1bbe8a6ef5ed2270e1d38308a35 Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Sun, 4 May 2025 16:28:17 +0530 Subject: [PATCH 03/11] prettier formatting --- docs/source/user-guide/sql/scalar_functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 406a93d73b3b..8ec3f6d972b0 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1880,7 +1880,7 @@ regexp_instr(string, pattern [, start [, N [, endoption [, flags [, subexpr ]]]] - **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. - **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. - **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function. -- **endoption**: Optional. If 0, returns the starting position of the match (default). If 1, returns the ending position of the match. Can be a constant, column, or function. +- **endoption**: Optional. If 0, returns the starting position of the match (default). If 1, returns the ending position of the match. Can be a constant, column, or function. - **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case - **m**: multi-line mode: ^ and $ match begin/end of line From 2e70881cdd7ebed118c40d052327a7c752865371 Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Thu, 8 May 2025 12:15:31 +0530 Subject: [PATCH 04/11] scalar_functions_formatting --- .../source/user-guide/sql/scalar_functions.md | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 8ec3f6d972b0..096458a86d97 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1793,6 +1793,7 @@ regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) The following regular expression functions are supported: - [regexp_count](#regexp_count) +- [regexp_instr](#regexp_instr) - [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace) @@ -1828,79 +1829,78 @@ regexp_count(str, regexp[, start, flags]) +---------------------------------------------------------------+ ``` -### `regexp_like` +### `regexp_instr` -Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise. +Returns the position in a string where the specified occurrence of a POSIX regular expression is located. ```sql -regexp_like(str, regexp[, flags]) +regexp_instr(str, regexp[, start, N, endoption, flags]) ``` #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start**: - **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1 +- **N**: - **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function. +- **endoption**: - **endoption**: Optional. If 0, returns the starting position of the match (default). If 1, returns the ending position of the match. Can be a constant, column, or function. - **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case - **m**: multi-line mode: ^ and $ match begin/end of line - **s**: allow . to match \n - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - **U**: swap the meaning of x* and x*? +- **subexpr**: Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match. #### Example ```sql -select regexp_like('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); -+--------------------------------------------------------+ -| regexp_like(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | -+--------------------------------------------------------+ -| true | -+--------------------------------------------------------+ -SELECT regexp_like('aBc', '(b|d)', 'i'); -+--------------------------------------------------+ -| regexp_like(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | -+--------------------------------------------------+ -| true | -+--------------------------------------------------+ +> SELECT regexp_instr('ABCDEF', 'c(.)(..)', 1, 1, 0, 'i', 2); ++---------------------------------------------------------------+ +| regexp_instr() | ++---------------------------------------------------------------+ +| 5 | ++---------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) - -### `regexp_instr` +### `regexp_like` -Returns the position in a string where the specified occurrence of a [regular expression](https://docs.rs/regex/latest/regex/#syntax) is located. +Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise. ```sql -regexp_instr(string, pattern [, start [, N [, endoption [, flags [, subexpr ]]]]]) +regexp_like(str, regexp[, flags]) ``` #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. -- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function. -- **endoption**: Optional. If 0, returns the starting position of the match (default). If 1, returns the ending position of the match. Can be a constant, column, or function. - **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case - **m**: multi-line mode: ^ and $ match begin/end of line - **s**: allow . to match \n - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - **U**: swap the meaning of x* and x*? -- **subexpr**: Optional. An integer indicating which parenthesized subexpression's position to return. Defaults to 0, which means the whole match #### Example ```sql -> select regexp_instr('ABCDEF', 'c(.)(..)', 1, 1, 0, 'i', 2); -+---------------------------------------------------------------+ -| regexp_instr(Utf8('ABCDEF'), Utf8('c(.)(..)'), Int64(1), | -| Int64(1), Int64(0),, 'i', Int64(2),) | -+---------------------------------------------------------------+ -| 5 | -+---------------------------------------------------------------+ +select regexp_like('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); ++--------------------------------------------------------+ +| regexp_like(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | ++--------------------------------------------------------+ +| true | ++--------------------------------------------------------+ +SELECT regexp_like('aBc', '(b|d)', 'i'); ++--------------------------------------------------+ +| regexp_like(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | ++--------------------------------------------------+ +| true | ++--------------------------------------------------+ ``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) + ### `regexp_match` Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. From c3131ed1724e2b1f7925a834d4e3fce64e796248 Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Tue, 20 May 2025 03:21:25 +0530 Subject: [PATCH 05/11] linting format macros --- datafusion/functions/src/regex/regexpinstr.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 31552a45feb4..ed75449d6085 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -525,14 +525,14 @@ fn compile_regex(regex: &str, flags: Option<&str>) -> Result "regexp_instr() does not support global flag".to_string(), )); } - format!("(?{}){}", flags, regex) + format!("(?{flags}){regex}") } }; Regex::new(&pattern).map_err(|_| { ArrowError::ComputeError(format!( - "Regular expression did not compile: {}", - pattern + "Regular expression did not compile: {pattern}" + )) }) } From eeeac8d864f488475f20e5004b510384cc56fe3b Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Tue, 20 May 2025 11:08:54 +0530 Subject: [PATCH 06/11] formatting --- datafusion/functions/src/regex/regexpinstr.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index ed75449d6085..74698709a3ee 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -530,10 +530,7 @@ fn compile_regex(regex: &str, flags: Option<&str>) -> Result }; Regex::new(&pattern).map_err(|_| { - ArrowError::ComputeError(format!( - "Regular expression did not compile: {pattern}" - - )) + ArrowError::ComputeError(format!("Regular expression did not compile: {pattern}")) }) } From 89e23153aa8085be1423137a778f6c39b3631c8c Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Thu, 12 Jun 2025 02:23:38 +0530 Subject: [PATCH 07/11] address comments to PR --- datafusion/functions/src/regex/regexpcount.rs | 4 +- datafusion/functions/src/regex/regexpinstr.rs | 340 +++++------------- .../test_files/regexp/regexp_instr.slt | 31 +- 3 files changed, 105 insertions(+), 270 deletions(-) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8f53bf8eb158..95146566de25 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -550,7 +550,7 @@ where } } -fn compile_and_cache_regex<'strings, 'cache>( +pub fn compile_and_cache_regex<'strings, 'cache>( regex: &'strings str, flags: Option<&'strings str>, regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, @@ -568,7 +568,7 @@ where Ok(result) } -fn compile_regex(regex: &str, flags: Option<&str>) -> Result { +pub fn compile_regex(regex: &str, flags: Option<&str>) -> Result { let pattern = match flags { None | Some("") => regex.to_string(), Some(flags) => { diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 74698709a3ee..e921e9001477 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -16,16 +16,20 @@ // under the License. use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType}; -use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{DataType, Int64Type, Field}; use arrow::datatypes::{ DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, }; +use arrow::ipc::Null; +use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::error::ArrowError; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::Values; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, TypeSignature::Uniform, Volatility, }; +use super::regexpcount::{compile_and_cache_regex, compile_regex}; use datafusion_macros::user_doc; use itertools::izip; use regex::Regex; @@ -36,7 +40,7 @@ use std::sync::Arc; #[user_doc( doc_section(label = "Regular Expression Functions"), description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.", - syntax_example = "regexp_instr(str, regexp[, start, N, endoption, flags])", + syntax_example = "regexp_instr(str, regexp[, start[, N[, flags]]])", sql_example = r#"```sql > SELECT regexp_instr('ABCDEF', 'c(.)(..)', 1, 1, 0, 'i', 2); +---------------------------------------------------------------+ @@ -55,10 +59,6 @@ use std::sync::Arc; name = "N", description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function." ), - argument( - name = "endoption", - description = "- **endoption**: Optional. If 0, returns the starting position of the match (default). If 1, returns the ending position of the match. Can be a constant, column, or function." - ), argument( name = "flags", description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: @@ -96,13 +96,9 @@ impl RegexpInstrFunc { Exact(vec![Utf8View, Utf8View, Int64, Int64]), Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), Exact(vec![Utf8, Utf8, Int64, Int64]), - Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64]), - Exact(vec![Utf8View, Utf8View, Int64, Utf8View]), - Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), - Exact(vec![Utf8, Utf8, Int64, Utf8]), - Exact(vec![Utf8, Utf8, Int64, Int64, Int64, Utf8]), - Exact(vec![Utf8, Utf8, Int64, Int64, Int64, Utf8, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8]), + Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]), ], Volatility::Immutable, ), @@ -164,8 +160,8 @@ impl ScalarUDFImpl for RegexpInstrFunc { pub fn regexp_instr_func(args: &[ArrayRef]) -> Result { let args_len = args.len(); - if !(2..=7).contains(&args_len) { - return exec_err!("regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 7."); + if !(2..=6).contains(&args_len) { + return exec_err!("regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6."); } let values = &args[0]; @@ -185,7 +181,6 @@ pub fn regexp_instr_func(args: &[ArrayRef]) -> Result { if args_len > 3 { Some(&args[3]) } else { None }, if args_len > 4 { Some(&args[4]) } else { None }, if args_len > 5 { Some(&args[5]) } else { None }, - if args_len > 6 { Some(&args[6]) } else { None }, ) .map_err(|e| e.into()) } @@ -213,7 +208,6 @@ pub fn regexp_instr( regex_array: &dyn Datum, start_array: Option<&dyn Datum>, nth_array: Option<&dyn Datum>, - endoption_array: Option<&dyn Datum>, flags_array: Option<&dyn Datum>, subexpr_array: Option<&dyn Datum>, ) -> Result { @@ -226,11 +220,6 @@ pub fn regexp_instr( let (nth, is_nth_scalar) = nth.get(); (Some(nth), is_nth_scalar) }); - let (endoption_array, is_endoption_scalar) = - endoption_array.map_or((None, true), |endoption| { - let (endoption, is_endoption_scalar) = endoption.get(); - (Some(endoption), is_endoption_scalar) - }); let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| { let (flags, is_flags_scalar) = flags.get(); (Some(flags), is_flags_scalar) @@ -250,8 +239,6 @@ pub fn regexp_instr( is_start_scalar, nth_array.map(|nth| nth.as_primitive::()), is_nth_scalar, - endoption_array.map(|endoption| endoption.as_primitive::()), - is_endoption_scalar, None, is_flags_scalar, subexpr_array.map(|subexpr| subexpr.as_primitive::()), @@ -265,8 +252,6 @@ pub fn regexp_instr( is_start_scalar, nth_array.map(|nth| nth.as_primitive::()), is_nth_scalar, - endoption_array.map(|endoption| endoption.as_primitive::()), - is_endoption_scalar, Some(flags_array.as_string::()), is_flags_scalar, subexpr_array.map(|subexpr| subexpr.as_primitive::()), @@ -280,8 +265,6 @@ pub fn regexp_instr( is_start_scalar, nth_array.map(|nth| nth.as_primitive::()), is_nth_scalar, - endoption_array.map(|endoption| endoption.as_primitive::()), - is_endoption_scalar, None, is_flags_scalar, subexpr_array.map(|subexpr| subexpr.as_primitive::()), @@ -295,8 +278,6 @@ pub fn regexp_instr( is_start_scalar, nth_array.map(|nth| nth.as_primitive::()), is_nth_scalar, - endoption_array.map(|endoption| endoption.as_primitive::()), - is_endoption_scalar, Some(flags_array.as_string::()), is_flags_scalar, subexpr_array.map(|subexpr| subexpr.as_primitive::()), @@ -310,8 +291,6 @@ pub fn regexp_instr( is_start_scalar, nth_array.map(|nth| nth.as_primitive::()), is_nth_scalar, - endoption_array.map(|endoption| endoption.as_primitive::()), - is_endoption_scalar, None, is_flags_scalar, subexpr_array.map(|subexpr| subexpr.as_primitive::()), @@ -325,8 +304,6 @@ pub fn regexp_instr( is_start_scalar, nth_array.map(|nth| nth.as_primitive::()), is_nth_scalar, - endoption_array.map(|endoption| endoption.as_primitive::()), - is_endoption_scalar, Some(flags_array.as_string_view()), is_flags_scalar, subexpr_array.map(|subexpr| subexpr.as_primitive::()), @@ -361,8 +338,6 @@ pub fn regexp_instr_inner<'a, S>( is_start_scalar: bool, nth_array: Option<&Int64Array>, is_nth_scalar: bool, - endoption_array: Option<&Int64Array>, - is_endoption_scalar: bool, flags_array: Option, is_flags_scalar: bool, subexp_array: Option<&Int64Array>, @@ -372,10 +347,11 @@ where S: StringArrayType<'a>, { let len = values.len(); + let regex_input = if is_regex_scalar || regex_array.len() == 1 { - ScalarOrArray::Scalar(regex_array.value(0)) + ScalarOrArray::Scalar(Some(regex_array.value(0))) } else { - let regex_vec: Vec<&str> = regex_array.iter().map(|v| v.unwrap_or("")).collect(); + let regex_vec: Vec> = regex_array.iter().map(|v| v).collect(); ScalarOrArray::Array(regex_vec) }; @@ -412,28 +388,6 @@ where ScalarOrArray::Array(vec![1; len]) }; - let endoption_input = if let Some(endoption) = endoption_array { - if is_endoption_scalar || endoption.len() == 1 { - ScalarOrArray::Scalar(endoption.value(0)) - } else { - let endoption_vec: Vec = (0..endoption.len()) - .map(|i| { - if endoption.is_null(i) { - 0 - } else { - endoption.value(i) - } - }) // handle nulls as 0 - .collect(); - ScalarOrArray::Array(endoption_vec) - } - } else if len == 1 { - ScalarOrArray::Scalar(0) - } - // Default nth = 0 - else { - ScalarOrArray::Array(vec![0; len]) - }; let flags_input = if let Some(ref flags) = flags_array { if is_flags_scalar || flags.len() == 1 { @@ -475,125 +429,112 @@ where let mut regex_cache = HashMap::new(); - let result: Result, ArrowError> = izip!( + let result: Result>, ArrowError> = izip!( values.iter(), regex_input.iter(len), start_input.iter(len), nth_input.iter(len), - endoption_input.iter(len), flags_input.iter(len), subexp_input.iter(len) ) - .map(|(value, regex, start, nth, endoption, flags, subexp)| { - if regex.is_empty() { - return Ok(0); - } - - let pattern = compile_and_cache_regex(regex, Some(flags), &mut regex_cache)?; + .map(|(value, regex, start, nth, flags, subexp)| { + // let regex = match regex { + // "" => return Ok(Some(0)), + // regex => regex, + // }; + match regex { + None => return Ok(None), + Some("") => return Ok(None), + Some(regex) => return get_index(value, regex, start, nth, subexp, Some(flags), &mut regex_cache), + }; - get_index(value, pattern, start, nth, endoption, subexp) - }) - .collect(); +}).collect(); Ok(Arc::new(Int64Array::from(result?))) } -fn compile_and_cache_regex<'strings, 'cache>( - regex: &'strings str, - flags: Option<&'strings str>, - regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, -) -> Result<&'cache Regex, ArrowError> -where - 'strings: 'cache, -{ - let result = match regex_cache.entry((regex, flags)) { - Entry::Occupied(occupied_entry) => occupied_entry.into_mut(), - Entry::Vacant(vacant_entry) => { - let compiled = compile_regex(regex, flags)?; - vacant_entry.insert(compiled) - } - }; - Ok(result) -} -fn compile_regex(regex: &str, flags: Option<&str>) -> Result { - let pattern = match flags { - None | Some("") => regex.to_string(), - Some(flags) => { - if flags.contains("g") { - return Err(ArrowError::ComputeError( - "regexp_instr() does not support global flag".to_string(), - )); - } - format!("(?{flags}){regex}") - } - }; - - Regex::new(&pattern).map_err(|_| { - ArrowError::ComputeError(format!("Regular expression did not compile: {pattern}")) - }) -} - -fn get_index( +fn get_index<'strings, 'cache>( value: Option<&str>, - pattern: &Regex, + pattern: &'strings str, start: i64, n: i64, - endoption: i64, subexpr: i64, -) -> Result { + flags: Option<&'strings str>, + regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex> +) -> Result, ArrowError> where +'strings: 'cache, +{ + let value = match value { - None | Some("") => return Ok(0), + None => return Ok(None), + Some("") => return Ok(Some(0)), Some(value) => value, }; - // let start = start.unwrap_or(1); + // let pattern = match pattern { + // None => return Ok(None), + // Some("") => return Ok(Some(0)), + // Some(pattern) => pattern, + // }; + + let pattern = compile_and_cache_regex(pattern, flags, regex_cache)?; if start < 1 { return Err(ArrowError::ComputeError( "regexp_instr() requires start to be 1-based".to_string(), )); } - // let n = n.unwrap_or(1); // Default to finding the first match if n < 1 { return Err(ArrowError::ComputeError( "N must be 1 or greater".to_string(), )); } - let find_slice = value.chars().skip(start as usize - 1).collect::(); - let matches: Vec<_> = pattern.find_iter(&find_slice).collect(); - - let mut result_index = 0; - if matches.len() < n as usize { - return Ok(result_index); // Return 0 if the N-th match was not found + // --- Simplified byte_start_offset calculation --- + let total_chars = value.chars().count() as i64; + let byte_start_offset = if start > total_chars { + // If start is beyond the total characters, it means we start searching + // after the string effectively. No matches possible. + return Ok(Some(0)); } else { - let nth_match = matches.get((n - 1) as usize).ok_or_else(|| { - ArrowError::ComputeError("N-th match not found".to_string()) - })?; - - let match_start = nth_match.start() as i64 + start; - let match_end = nth_match.end() as i64 + start; - - result_index = match endoption { - 1 => match_end, // Return end position of match - _ => match_start, // Default: Return start position - }; - } - // Find the N-th match (1-based index) + // Get the byte offset for the (start - 1)-th character (0-based) + value + .char_indices() + .nth((start - 1) as usize) + .map(|(idx, _)| idx) + .unwrap_or(0) // Should not happen if start is valid and <= total_chars + }; + // --- End simplified calculation --- - // Handle subexpression capturing (if requested) + let search_slice = &value[byte_start_offset..]; + // Handle subexpression capturing first, as it takes precedence if subexpr > 0 { - if let Some(captures) = pattern.captures(&find_slice) { + if let Some(captures) = pattern.captures(search_slice) { if let Some(matched) = captures.get(subexpr as usize) { - return Ok(matched.start() as i64 + start); + // Convert byte offset relative to search_slice back to 1-based character offset + // relative to the original `value` string. + let start_char_offset = value[..byte_start_offset + matched.start()] + .chars() + .count() as i64 + + 1; + return Ok(Some(start_char_offset)); } } - return Ok(0); // Return 0 if the subexpression was not found + return Ok(Some(0)); // Return 0 if the subexpression was not found } - Ok(result_index) + // Use nth to get the N-th match (n is 1-based, nth is 0-based) + if let Some(mat) = pattern.find_iter(search_slice).nth((n - 1) as usize) { + // Convert byte offset relative to search_slice back to 1-based character offset + // relative to the original `value` string. + let match_start_byte_offset = byte_start_offset + mat.start(); + let match_start_char_offset = value[..match_start_byte_offset].chars().count() as i64 + 1; + Ok(Some(match_start_char_offset)) + } else { + Ok(Some(0)) // Return 0 if the N-th match was not found + } } #[cfg(test)] @@ -608,7 +549,6 @@ mod tests { test_case_sensitive_regexp_instr_scalar(); test_case_sensitive_regexp_instr_scalar_start(); test_case_sensitive_regexp_instr_scalar_nth(); - test_case_sensitive_regexp_instr_scalar_endoption(); test_case_sensitive_regexp_instr_array::>(); test_case_sensitive_regexp_instr_array::>(); @@ -621,10 +561,6 @@ mod tests { test_case_sensitive_regexp_instr_array_nth::>(); test_case_sensitive_regexp_instr_array_nth::>(); test_case_sensitive_regexp_instr_array_nth::(); - - test_case_sensitive_regexp_instr_array_endoption::>(); - test_case_sensitive_regexp_instr_array_endoption::>(); - test_case_sensitive_regexp_instr_array_endoption::(); } fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result { @@ -654,13 +590,16 @@ mod tests { "abcdefg", "xyz123xyz", "no match here", - "", "abc", - "", + "ДатаФусион数据融合📊🔥", ]; - let regex = ["o", "d", "123", "z", "gg", "", ""]; + let regex = ["o", "d", "123", "z", "gg", "📊"]; + + let expected: Vec = vec![5, 4, 4, 0, 0, 15]; - let expected: Vec = vec![5, 4, 4, 0, 0, 0, 0]; + // let values = [""]; + // let regex = [""]; + // let expected: Vec = vec![0]; izip!(values.iter(), regex.iter()) .enumerate() @@ -820,86 +759,6 @@ mod tests { }); } - fn test_case_sensitive_regexp_instr_scalar_endoption() { - let values = ["abcdefg", "abcdefg"]; - let regex = ["cd", "cd"]; - let start = [1, 1]; - let nth = [1, 1]; - let endoption = [0, 1]; - let expected: Vec = vec![3, 5]; - - izip!( - values.iter(), - regex.iter(), - start.iter(), - nth.iter(), - endoption.iter() - ) - .enumerate() - .for_each(|(pos, (&v, &r, &s, &n, &e))| { - // utf8 - let v_sv = ScalarValue::Utf8(Some(v.to_string())); - let regex_sv = ScalarValue::Utf8(Some(r.to_string())); - let start_sv = ScalarValue::Int64(Some(s)); - let nth_sv = ScalarValue::Int64(Some(n)); - let endoption_sv = ScalarValue::Int64(Some(e)); - let expected = expected.get(pos).cloned(); - let re = regexp_instr_with_scalar_values(&[ - v_sv, - regex_sv, - start_sv.clone(), - nth_sv.clone(), - endoption_sv.clone(), - ]); - match re { - Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { - assert_eq!(v, expected, "regexp_instr scalar test failed"); - } - _ => panic!("Unexpected result"), - } - - // largeutf8 - let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); - let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); - let start_sv = ScalarValue::Int64(Some(s)); - let nth_sv = ScalarValue::Int64(Some(n)); - let endoption_sv = ScalarValue::Int64(Some(e)); - let re = regexp_instr_with_scalar_values(&[ - v_sv, - regex_sv, - start_sv.clone(), - nth_sv.clone(), - endoption_sv.clone(), - ]); - match re { - Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { - assert_eq!(v, expected, "regexp_instr scalar test failed"); - } - _ => panic!("Unexpected result"), - } - - // utf8view - let v_sv = ScalarValue::Utf8View(Some(v.to_string())); - let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); - let start_sv = ScalarValue::Int64(Some(s)); - let nth_sv = ScalarValue::Int64(Some(n)); - let endoption_sv = ScalarValue::Int64(Some(e)); - let re = regexp_instr_with_scalar_values(&[ - v_sv, - regex_sv, - start_sv.clone(), - nth_sv.clone(), - endoption_sv.clone(), - ]); - match re { - Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { - assert_eq!(v, expected, "regexp_instr scalar test failed"); - } - _ => panic!("Unexpected result"), - } - }); - } - fn test_case_sensitive_regexp_instr_array() where A: From> + Array + 'static, @@ -910,12 +769,10 @@ mod tests { "xyz123xyz", "no match here", "", - "abc", - "", ]); - let regex = A::from(vec!["o", "d", "123", "z", "gg", "", ""]); + let regex = A::from(vec!["o", "d", "123", "z", "gg"]); - let expected = Int64Array::from(vec![5, 4, 4, 0, 0, 0, 0]); + let expected = Int64Array::from(vec![5, 4, 4, 0, 0]); let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); assert_eq!(re.as_ref(), &expected); } @@ -953,27 +810,4 @@ mod tests { .unwrap(); assert_eq!(re.as_ref(), &expected); } - - fn test_case_sensitive_regexp_instr_array_endoption() - where - A: From> + Array + 'static, - { - let values = A::from(vec!["abcdefg", "abcdefg"]); - let regex = A::from(vec!["cd", "cd"]); - let start = Int64Array::from(vec![1, 1]); - let nth = Int64Array::from(vec![1, 1]); - let endoption = Int64Array::from(vec![0, 1]); - - let expected = Int64Array::from(vec![3, 5]); - - let re = regexp_instr_func(&[ - Arc::new(values), - Arc::new(regex), - Arc::new(start), - Arc::new(nth), - Arc::new(endoption), - ]) - .unwrap(); - assert_eq!(re.as_ref(), &expected); - } } diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt b/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt index 021db3dd74a2..c651422142d4 100644 --- a/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt +++ b/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt @@ -39,12 +39,12 @@ SELECT regexp_instr('123123123123', '123', 33); 0 query I -SELECT regexp_instr('ABCABCABCABC', 'Abc', 1, 2, 0, ''); +SELECT regexp_instr('ABCABCABCABC', 'Abc', 1, 2, ''); ---- 0 query I -SELECT regexp_instr('ABCABCABCABC', 'Abc', 1, 2, 0, 'i'); +SELECT regexp_instr('ABCABCABCABC', 'Abc', 1, 2, 'i'); ---- 4 @@ -59,7 +59,7 @@ SELECT regexp_instr('123123123123', '123', -3); query I SELECT regexp_instr(str, pattern) FROM regexp_test_data; ---- -0 +NULL 1 1 0 @@ -75,7 +75,7 @@ SELECT regexp_instr(str, pattern) FROM regexp_test_data; query I SELECT regexp_instr(str, pattern, start) FROM regexp_test_data; ---- -0 +NULL 1 1 0 @@ -100,7 +100,7 @@ FROM regexp_test_data; query I SELECT regexp_instr(str, pattern, start) FROM t_stringview; ---- -0 +NULL 1 1 0 @@ -120,7 +120,7 @@ SELECT regexp_instr( arrow_cast(start, 'Int32') ) FROM t_stringview; ---- -0 +NULL 1 1 0 @@ -136,22 +136,23 @@ SELECT regexp_instr( query I SELECT regexp_instr(NULL, NULL); ---- -0 +NULL query I SELECT regexp_instr(NULL, 'a'); ---- -0 +NULL query I SELECT regexp_instr('a', NULL); ---- -0 +NULL query I -SELECT regexp_instr(NULL, NULL, NULL); +SELECT regexp_instr('😀abcdef', 'abc'); ---- -0 +2 + statement ok CREATE TABLE empty_table (str varchar, pattern varchar, start int); @@ -170,10 +171,10 @@ INSERT INTO empty_table VALUES query I SELECT regexp_instr(str, pattern, start) FROM empty_table; ---- -0 -0 -0 -0 +NULL +NULL +NULL +NULL statement ok DROP TABLE t_stringview; From fad6e7bb01c3539df0ca2812620b76b7cfd22394 Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Thu, 12 Jun 2025 03:25:23 +0530 Subject: [PATCH 08/11] formatting --- datafusion/functions/src/regex/regexpinstr.rs | 63 +++++++++---------- .../source/user-guide/sql/scalar_functions.md | 9 ++- 2 files changed, 33 insertions(+), 39 deletions(-) diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index e921e9001477..0e66662129a4 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -15,25 +15,22 @@ // specific language governing permissions and limitations // under the License. +use super::regexpcount::compile_and_cache_regex; + use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType}; -use arrow::datatypes::{DataType, Int64Type, Field}; +use arrow::datatypes::{DataType, Int64Type}; use arrow::datatypes::{ DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, }; -use arrow::ipc::Null; -use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::error::ArrowError; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; -use datafusion_expr::Values; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, TypeSignature::Uniform, Volatility, }; -use super::regexpcount::{compile_and_cache_regex, compile_regex}; use datafusion_macros::user_doc; use itertools::izip; use regex::Regex; -use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::Arc; @@ -42,11 +39,11 @@ use std::sync::Arc; description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.", syntax_example = "regexp_instr(str, regexp[, start[, N[, flags]]])", sql_example = r#"```sql -> SELECT regexp_instr('ABCDEF', 'c(.)(..)', 1, 1, 0, 'i', 2); +> SELECT regexp_instr('ABCDEF', 'c(.)(..)'); +---------------------------------------------------------------+ -| regexp_instr() | +| regexp_instr(Utf8("ABCDEF"),Utf8("c(.)(..)")) | +---------------------------------------------------------------+ -| 5 | +| 2 | +---------------------------------------------------------------+ ```"#, standard_argument(name = "str", prefix = "String"), @@ -347,7 +344,7 @@ where S: StringArrayType<'a>, { let len = values.len(); - + let regex_input = if is_regex_scalar || regex_array.len() == 1 { ScalarOrArray::Scalar(Some(regex_array.value(0))) } else { @@ -388,7 +385,6 @@ where ScalarOrArray::Array(vec![1; len]) }; - let flags_input = if let Some(ref flags) = flags_array { if is_flags_scalar || flags.len() == 1 { ScalarOrArray::Scalar(flags.value(0)) @@ -438,22 +434,27 @@ where subexp_input.iter(len) ) .map(|(value, regex, start, nth, flags, subexp)| { - // let regex = match regex { - // "" => return Ok(Some(0)), - // regex => regex, - // }; match regex { None => return Ok(None), Some("") => return Ok(None), - Some(regex) => return get_index(value, regex, start, nth, subexp, Some(flags), &mut regex_cache), + Some(regex) => { + return get_index( + value, + regex, + start, + nth, + subexp, + Some(flags), + &mut regex_cache, + ) + } }; - -}).collect(); + }) + .collect(); Ok(Arc::new(Int64Array::from(result?))) } - fn get_index<'strings, 'cache>( value: Option<&str>, pattern: &'strings str, @@ -461,23 +462,17 @@ fn get_index<'strings, 'cache>( n: i64, subexpr: i64, flags: Option<&'strings str>, - regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex> -) -> Result, ArrowError> where -'strings: 'cache, + regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, +) -> Result, ArrowError> +where + 'strings: 'cache, { - let value = match value { None => return Ok(None), Some("") => return Ok(Some(0)), Some(value) => value, }; - // let pattern = match pattern { - // None => return Ok(None), - // Some("") => return Ok(Some(0)), - // Some(pattern) => pattern, - // }; - let pattern = compile_and_cache_regex(pattern, flags, regex_cache)?; if start < 1 { return Err(ArrowError::ComputeError( @@ -515,10 +510,9 @@ fn get_index<'strings, 'cache>( if let Some(matched) = captures.get(subexpr as usize) { // Convert byte offset relative to search_slice back to 1-based character offset // relative to the original `value` string. - let start_char_offset = value[..byte_start_offset + matched.start()] - .chars() - .count() as i64 - + 1; + let start_char_offset = + value[..byte_start_offset + matched.start()].chars().count() as i64 + + 1; return Ok(Some(start_char_offset)); } } @@ -530,7 +524,8 @@ fn get_index<'strings, 'cache>( // Convert byte offset relative to search_slice back to 1-based character offset // relative to the original `value` string. let match_start_byte_offset = byte_start_offset + mat.start(); - let match_start_char_offset = value[..match_start_byte_offset].chars().count() as i64 + 1; + let match_start_char_offset = + value[..match_start_byte_offset].chars().count() as i64 + 1; Ok(Some(match_start_char_offset)) } else { Ok(Some(0)) // Return 0 if the N-th match was not found diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 096458a86d97..8ac736e1e33d 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1834,7 +1834,7 @@ regexp_count(str, regexp[, start, flags]) Returns the position in a string where the specified occurrence of a POSIX regular expression is located. ```sql -regexp_instr(str, regexp[, start, N, endoption, flags]) +regexp_instr(str, regexp[, start[, N[, flags]]]) ``` #### Arguments @@ -1843,7 +1843,6 @@ regexp_instr(str, regexp[, start, N, endoption, flags]) - **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. - **start**: - **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1 - **N**: - **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function. -- **endoption**: - **endoption**: Optional. If 0, returns the starting position of the match (default). If 1, returns the ending position of the match. Can be a constant, column, or function. - **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case - **m**: multi-line mode: ^ and $ match begin/end of line @@ -1855,11 +1854,11 @@ regexp_instr(str, regexp[, start, N, endoption, flags]) #### Example ```sql -> SELECT regexp_instr('ABCDEF', 'c(.)(..)', 1, 1, 0, 'i', 2); +> SELECT regexp_instr('ABCDEF', 'c(.)(..)'); +---------------------------------------------------------------+ -| regexp_instr() | +| regexp_instr(Utf8("ABCDEF"),Utf8("c(.)(..)")) | +---------------------------------------------------------------+ -| 5 | +| 2 | +---------------------------------------------------------------+ ``` From 93c35c8cc3160d1ac040e207df6a3847ab326878 Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Thu, 12 Jun 2025 12:26:24 +0530 Subject: [PATCH 09/11] clippy --- datafusion/functions/src/regex/regexpinstr.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 0e66662129a4..5ecf367e3b52 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -348,7 +348,7 @@ where let regex_input = if is_regex_scalar || regex_array.len() == 1 { ScalarOrArray::Scalar(Some(regex_array.value(0))) } else { - let regex_vec: Vec> = regex_array.iter().map(|v| v).collect(); + let regex_vec: Vec> = regex_array.iter().collect(); ScalarOrArray::Array(regex_vec) }; @@ -435,10 +435,10 @@ where ) .map(|(value, regex, start, nth, flags, subexp)| { match regex { - None => return Ok(None), - Some("") => return Ok(None), + None => Ok(None), + Some("") => Ok(None), Some(regex) => { - return get_index( + get_index( value, regex, start, @@ -448,7 +448,7 @@ where &mut regex_cache, ) } - }; + } }) .collect(); From 990a6537cb62531313f04f9d85ba3d5a1c57bf39 Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Thu, 12 Jun 2025 18:48:21 +0530 Subject: [PATCH 10/11] fmt --- datafusion/functions/src/regex/regexpinstr.rs | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 5ecf367e3b52..1f9211fdde5d 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -433,22 +433,18 @@ where flags_input.iter(len), subexp_input.iter(len) ) - .map(|(value, regex, start, nth, flags, subexp)| { - match regex { - None => Ok(None), - Some("") => Ok(None), - Some(regex) => { - get_index( - value, - regex, - start, - nth, - subexp, - Some(flags), - &mut regex_cache, - ) - } - } + .map(|(value, regex, start, nth, flags, subexp)| match regex { + None => Ok(None), + Some("") => Ok(None), + Some(regex) => get_index( + value, + regex, + start, + nth, + subexp, + Some(flags), + &mut regex_cache, + ), }) .collect(); From 1120043c0c9e85f1f11d557669071e1b610ad525 Mon Sep 17 00:00:00 2001 From: Nirnay Roy Date: Wed, 18 Jun 2025 17:32:41 +0530 Subject: [PATCH 11/11] address docs typo --- datafusion/functions/src/regex/mod.rs | 41 ++++++++++++++++++- datafusion/functions/src/regex/regexpcount.rs | 39 +----------------- datafusion/functions/src/regex/regexpinstr.rs | 10 ++--- .../source/user-guide/sql/scalar_functions.md | 6 +-- 4 files changed, 50 insertions(+), 46 deletions(-) diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index b0278ac9df61..6b3919bd2b75 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -17,8 +17,11 @@ //! "regex" DataFusion functions +use arrow::error::ArrowError; +use regex::Regex; +use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::sync::Arc; - pub mod regexpcount; pub mod regexpinstr; pub mod regexplike; @@ -124,3 +127,39 @@ pub fn functions() -> Vec> { regexp_replace(), ] } + +pub fn compile_and_cache_regex<'strings, 'cache>( + regex: &'strings str, + flags: Option<&'strings str>, + regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, +) -> Result<&'cache Regex, ArrowError> +where + 'strings: 'cache, +{ + let result = match regex_cache.entry((regex, flags)) { + Entry::Occupied(occupied_entry) => occupied_entry.into_mut(), + Entry::Vacant(vacant_entry) => { + let compiled = compile_regex(regex, flags)?; + vacant_entry.insert(compiled) + } + }; + Ok(result) +} + +pub fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + format!("(?{flags}){regex}") + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!("Regular expression did not compile: {pattern}")) + }) +} diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 95146566de25..9a59cad74b5b 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::regex::{compile_and_cache_regex, compile_regex}; use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType}; use arrow::datatypes::{DataType, Int64Type}; use arrow::datatypes::{ @@ -29,10 +30,10 @@ use datafusion_expr::{ use datafusion_macros::user_doc; use itertools::izip; use regex::Regex; -use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::Arc; +// Ensure the `compile_and_cache_regex` function is defined in the `regex` module or imported correctly. #[user_doc( doc_section(label = "Regular Expression Functions"), description = "Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.", @@ -550,42 +551,6 @@ where } } -pub fn compile_and_cache_regex<'strings, 'cache>( - regex: &'strings str, - flags: Option<&'strings str>, - regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, -) -> Result<&'cache Regex, ArrowError> -where - 'strings: 'cache, -{ - let result = match regex_cache.entry((regex, flags)) { - Entry::Occupied(occupied_entry) => occupied_entry.into_mut(), - Entry::Vacant(vacant_entry) => { - let compiled = compile_regex(regex, flags)?; - vacant_entry.insert(compiled) - } - }; - Ok(result) -} - -pub fn compile_regex(regex: &str, flags: Option<&str>) -> Result { - let pattern = match flags { - None | Some("") => regex.to_string(), - Some(flags) => { - if flags.contains("g") { - return Err(ArrowError::ComputeError( - "regexp_count() does not support global flag".to_string(), - )); - } - format!("(?{flags}){regex}") - } - }; - - Regex::new(&pattern).map_err(|_| { - ArrowError::ComputeError(format!("Regular expression did not compile: {pattern}")) - }) -} - fn count_matches( value: Option<&str>, pattern: &Regex, diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 1f9211fdde5d..dafd3cdf61d5 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use super::regexpcount::compile_and_cache_regex; - use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType}; use arrow::datatypes::{DataType, Int64Type}; use arrow::datatypes::{ @@ -34,16 +32,18 @@ use regex::Regex; use std::collections::HashMap; use std::sync::Arc; +use crate::regex::compile_and_cache_regex; + #[user_doc( doc_section(label = "Regular Expression Functions"), description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.", syntax_example = "regexp_instr(str, regexp[, start[, N[, flags]]])", sql_example = r#"```sql -> SELECT regexp_instr('ABCDEF', 'c(.)(..)'); +> SELECT regexp_instr('ABCDEF', 'C(.)(..)'); +---------------------------------------------------------------+ -| regexp_instr(Utf8("ABCDEF"),Utf8("c(.)(..)")) | +| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)")) | +---------------------------------------------------------------+ -| 2 | +| 3 | +---------------------------------------------------------------+ ```"#, standard_argument(name = "str", prefix = "String"), diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 8ac736e1e33d..885c8bbd2b98 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1854,11 +1854,11 @@ regexp_instr(str, regexp[, start[, N[, flags]]]) #### Example ```sql -> SELECT regexp_instr('ABCDEF', 'c(.)(..)'); +> SELECT regexp_instr('ABCDEF', 'C(.)(..)'); +---------------------------------------------------------------+ -| regexp_instr(Utf8("ABCDEF"),Utf8("c(.)(..)")) | +| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)")) | +---------------------------------------------------------------+ -| 2 | +| 3 | +---------------------------------------------------------------+ ```