Skip to content

Commit 321b74f

Browse files
authored
feat(cubesql): starts_with, ends_with, LOWER(?column) = ?literal (cube-js#5310)
1 parent 4b82c0d commit 321b74f

File tree

6 files changed

+425
-73
lines changed

6 files changed

+425
-73
lines changed

rust/cubesql/cubesql/src/compile/engine/udf.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,35 @@ pub fn create_timediff_udf() -> ScalarUDF {
676676
)
677677
}
678678

679+
pub fn create_ends_with_udf() -> ScalarUDF {
680+
let fun = make_scalar_function(move |args: &[ArrayRef]| {
681+
assert!(args.len() == 2);
682+
683+
let string_array = downcast_string_arg!(args[0], "string", i32);
684+
let prefix_array = downcast_string_arg!(args[1], "prefix", i32);
685+
686+
let result = string_array
687+
.iter()
688+
.zip(prefix_array.iter())
689+
.map(|(string, prefix)| match (string, prefix) {
690+
(Some(string), Some(prefix)) => Some(string.ends_with(prefix)),
691+
_ => None,
692+
})
693+
.collect::<BooleanArray>();
694+
695+
Ok(Arc::new(result) as ArrayRef)
696+
});
697+
698+
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Boolean)));
699+
700+
ScalarUDF::new(
701+
"ends_with",
702+
&Signature::exact(vec![DataType::Utf8, DataType::Utf8], Volatility::Immutable),
703+
&return_type,
704+
&fun,
705+
)
706+
}
707+
679708
// https://docs.aws.amazon.com/redshift/latest/dg/r_DATEDIFF_function.html
680709
pub fn create_datediff_udf() -> ScalarUDF {
681710
let fun = make_scalar_function(move |args: &[ArrayRef]| {

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 132 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ use self::{
4141
create_current_schemas_udf, create_current_timestamp_udf, create_current_user_udf,
4242
create_date_add_udf, create_date_sub_udf, create_date_udf, create_dateadd_udf,
4343
create_datediff_udf, create_dayofmonth_udf, create_dayofweek_udf, create_dayofyear_udf,
44-
create_db_udf, create_format_type_udf, create_generate_series_udtf,
45-
create_generate_subscripts_udtf, create_has_schema_privilege_udf, create_hour_udf,
46-
create_if_udf, create_instr_udf, create_interval_mul_udf, create_isnull_udf,
47-
create_json_build_object_udf, create_least_udf, create_locate_udf, create_makedate_udf,
48-
create_measure_udaf, create_minute_udf, create_pg_backend_pid_udf,
49-
create_pg_datetime_precision_udf, create_pg_expandarray_udtf,
50-
create_pg_get_constraintdef_udf, create_pg_get_expr_udf,
44+
create_db_udf, create_ends_with_udf, create_format_type_udf,
45+
create_generate_series_udtf, create_generate_subscripts_udtf,
46+
create_has_schema_privilege_udf, create_hour_udf, create_if_udf, create_instr_udf,
47+
create_interval_mul_udf, create_isnull_udf, create_json_build_object_udf,
48+
create_least_udf, create_locate_udf, create_makedate_udf, create_measure_udaf,
49+
create_minute_udf, create_pg_backend_pid_udf, create_pg_datetime_precision_udf,
50+
create_pg_expandarray_udtf, create_pg_get_constraintdef_udf, create_pg_get_expr_udf,
5151
create_pg_get_serial_sequence_udf, create_pg_get_userbyid_udf,
5252
create_pg_is_other_temp_schema, create_pg_my_temp_schema,
5353
create_pg_numeric_precision_udf, create_pg_numeric_scale_udf,
@@ -1128,6 +1128,7 @@ WHERE `TABLE_SCHEMA` = '{}'",
11281128
ctx.register_udf(create_json_build_object_udf());
11291129
ctx.register_udf(create_regexp_substr_udf());
11301130
ctx.register_udf(create_interval_mul_udf());
1131+
ctx.register_udf(create_ends_with_udf());
11311132

11321133
// udaf
11331134
ctx.register_udaf(create_measure_udaf());
@@ -1766,6 +1767,112 @@ mod tests {
17661767
)
17671768
}
17681769

1770+
#[tokio::test]
1771+
async fn test_starts_with() {
1772+
let query_plan = convert_select_to_query_plan(
1773+
"SELECT COUNT(*) as cnt FROM KibanaSampleDataEcommerce WHERE starts_with(customer_gender, 'fe')"
1774+
.to_string(),
1775+
DatabaseProtocol::PostgreSQL,
1776+
)
1777+
.await;
1778+
1779+
let cube_scan = query_plan.as_logical_plan().find_cube_scan();
1780+
1781+
assert_eq!(
1782+
cube_scan.request,
1783+
V1LoadRequestQuery {
1784+
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string(),]),
1785+
segments: Some(vec![]),
1786+
dimensions: Some(vec![]),
1787+
time_dimensions: None,
1788+
order: None,
1789+
limit: None,
1790+
offset: None,
1791+
filters: Some(vec![V1LoadRequestQueryFilterItem {
1792+
member: Some("KibanaSampleDataEcommerce.customer_gender".to_string()),
1793+
operator: Some("startsWith".to_string()),
1794+
values: Some(vec!["fe".to_string()]),
1795+
or: None,
1796+
and: None
1797+
}])
1798+
}
1799+
)
1800+
}
1801+
1802+
#[tokio::test]
1803+
async fn test_ends_with_query() {
1804+
let query_plan = convert_select_to_query_plan(
1805+
"SELECT COUNT(*) as cnt FROM KibanaSampleDataEcommerce WHERE ends_with(customer_gender, 'emale')"
1806+
.to_string(),
1807+
DatabaseProtocol::PostgreSQL,
1808+
)
1809+
.await;
1810+
1811+
let cube_scan = query_plan.as_logical_plan().find_cube_scan();
1812+
1813+
assert_eq!(
1814+
cube_scan.request,
1815+
V1LoadRequestQuery {
1816+
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string(),]),
1817+
segments: Some(vec![]),
1818+
dimensions: Some(vec![]),
1819+
time_dimensions: None,
1820+
order: None,
1821+
limit: None,
1822+
offset: None,
1823+
filters: Some(vec![V1LoadRequestQueryFilterItem {
1824+
member: Some("KibanaSampleDataEcommerce.customer_gender".to_string()),
1825+
operator: Some("endsWith".to_string()),
1826+
values: Some(vec!["emale".to_string()]),
1827+
or: None,
1828+
and: None
1829+
}])
1830+
}
1831+
)
1832+
}
1833+
1834+
#[tokio::test]
1835+
async fn test_lower_equals_thoughtspot() {
1836+
let query_plan = convert_select_to_query_plan(
1837+
"SELECT COUNT(*) as cnt FROM KibanaSampleDataEcommerce WHERE LOWER(customer_gender) = 'female'"
1838+
.to_string(),
1839+
DatabaseProtocol::PostgreSQL,
1840+
)
1841+
.await;
1842+
1843+
let cube_scan = query_plan.as_logical_plan().find_cube_scan();
1844+
1845+
assert_eq!(
1846+
cube_scan.request,
1847+
V1LoadRequestQuery {
1848+
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string(),]),
1849+
segments: Some(vec![]),
1850+
dimensions: Some(vec![]),
1851+
time_dimensions: None,
1852+
order: None,
1853+
limit: None,
1854+
offset: None,
1855+
// TODO: Migrate to equalsLower operator, when it will be available in Cube?
1856+
filters: Some(vec![
1857+
V1LoadRequestQueryFilterItem {
1858+
member: Some("KibanaSampleDataEcommerce.customer_gender".to_string()),
1859+
operator: Some("startsWith".to_string()),
1860+
values: Some(vec!["female".to_string()]),
1861+
or: None,
1862+
and: None
1863+
},
1864+
V1LoadRequestQueryFilterItem {
1865+
member: Some("KibanaSampleDataEcommerce.customer_gender".to_string()),
1866+
operator: Some("endsWith".to_string()),
1867+
values: Some(vec!["female".to_string()]),
1868+
or: None,
1869+
and: None
1870+
}
1871+
])
1872+
}
1873+
)
1874+
}
1875+
17691876
#[tokio::test]
17701877
async fn test_change_user_via_in_filter_thoughtspot() {
17711878
let query_plan = convert_select_to_query_plan(
@@ -5219,6 +5326,24 @@ ORDER BY \"COUNT(count)\" DESC"
52195326
Ok(())
52205327
}
52215328

5329+
#[tokio::test]
5330+
async fn test_ends_with() -> Result<(), CubeError> {
5331+
insta::assert_snapshot!(
5332+
"ends_with",
5333+
execute_query(
5334+
"select \
5335+
ends_with('rust is killing me', 'me') as r1,
5336+
ends_with('rust is killing me', 'no') as r2
5337+
"
5338+
.to_string(),
5339+
DatabaseProtocol::MySQL
5340+
)
5341+
.await?
5342+
);
5343+
5344+
Ok(())
5345+
}
5346+
52225347
#[tokio::test]
52235348
async fn test_locate() -> Result<(), CubeError> {
52245349
assert_eq!(

rust/cubesql/cubesql/src/compile/rewrite/analysis.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,10 @@ impl LogicalPlanAnalysis {
276276
push_referenced_columns(params[1], &mut vec)?;
277277
Some(vec)
278278
}
279+
LogicalPlanLanguage::ScalarUDFExpr(params) => {
280+
push_referenced_columns(params[1], &mut vec)?;
281+
Some(vec)
282+
}
279283
LogicalPlanLanguage::AggregateFunctionExpr(params) => {
280284
push_referenced_columns(params[1], &mut vec)?;
281285
Some(vec)
@@ -297,7 +301,8 @@ impl LogicalPlanAnalysis {
297301
| LogicalPlanLanguage::CaseExprElseExpr(params)
298302
| LogicalPlanLanguage::CaseExprExpr(params)
299303
| LogicalPlanLanguage::AggregateFunctionExprArgs(params)
300-
| LogicalPlanLanguage::ScalarFunctionExprArgs(params) => {
304+
| LogicalPlanLanguage::ScalarFunctionExprArgs(params)
305+
| LogicalPlanLanguage::ScalarUDFExprArgs(params) => {
301306
for p in params.iter() {
302307
vec.extend(referenced_columns(*p)?.into_iter());
303308
}

rust/cubesql/cubesql/src/compile/rewrite/mod.rs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -508,11 +508,24 @@ fn list_expr(list_type: impl Display, list: Vec<impl Display>) -> String {
508508
}
509509

510510
fn udf_expr(fun_name: impl Display, args: Vec<impl Display>) -> String {
511-
format!(
512-
"(ScalarUDFExpr ScalarUDFExprFun:{} {})",
513-
fun_name,
514-
list_expr("ScalarUDFExprArgs", args)
515-
)
511+
udf_expr_var_arg(fun_name, list_expr("ScalarUDFExprArgs", args))
512+
}
513+
514+
fn udf_expr_var_arg(fun_name: impl Display, arg_list: impl Display) -> String {
515+
let prefix = if fun_name.to_string().starts_with("?") {
516+
""
517+
} else {
518+
"ScalarUDFExprFun:"
519+
};
520+
format!("(ScalarUDFExpr {}{} {})", prefix, fun_name, arg_list)
521+
}
522+
523+
fn udf_fun_expr_args(left: impl Display, right: impl Display) -> String {
524+
format!("(ScalarUDFExprArgs {} {})", left, right)
525+
}
526+
527+
fn udf_fun_expr_args_empty_tail() -> String {
528+
"ScalarUDFExprArgs".to_string()
516529
}
517530

518531
fn fun_expr(fun_name: impl Display, args: Vec<impl Display>) -> String {

0 commit comments

Comments
 (0)