Skip to content

Commit af0b6f2

Browse files
committed
implemented compare between
1 parent c835000 commit af0b6f2

File tree

4 files changed

+44
-3
lines changed

4 files changed

+44
-3
lines changed

src/include/to_substrait.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class DuckDBToSubstrait {
8888
void TransformFunctionExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
8989
static void TransformConstantExpression(Expression &dexpr, substrait::Expression &sexpr);
9090
void TransformComparisonExpression(Expression &dexpr, substrait::Expression &sexpr);
91+
void TransformBetweenExpression(Expression &dexpr, substrait::Expression &sexpr);
9192
void TransformConjunctionExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
9293
void TransformNotNullExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
9394
void TransformIsNullExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);

src/to_substrait.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,29 @@ void DuckDBToSubstrait::TransformComparisonExpression(Expression &dexpr, substra
414414
*scalar_fun->mutable_output_type() = DuckToSubstraitType(dcomp.return_type);
415415
}
416416

417+
/// Serializes a DuckDB BETWEEN expression into a Substrait scalar function call.
///
/// When both bounds are inclusive the expression is emitted as the three-argument
/// `between(input, lower, upper)` function. When either bound is exclusive it is
/// emitted as a conjunction of two comparisons instead, because emitting `between`
/// unconditionally would silently turn `x > a` into `x >= a` (or `x < b` into `x <= b`).
///
/// @param dexpr  Expression to serialize; its type must be COMPARE_BETWEEN.
/// @param sexpr  Substrait expression message that is filled in.
/// @throws InternalException if dexpr is not a COMPARE_BETWEEN expression.
void DuckDBToSubstrait::TransformBetweenExpression(Expression &dexpr, substrait::Expression &sexpr) {
	// Validate the expression type before downcasting.
	if (dexpr.type != ExpressionType::COMPARE_BETWEEN) {
		throw InternalException("Not a between comparison expression");
	}
	auto &dcomp = dexpr.Cast<BoundBetweenExpression>();

	// NOTE(review): relies on BoundBetweenExpression exposing `lower_inclusive` /
	// `upper_inclusive` flags, and on Substrait `between` being inclusive on both ends.
	if (dcomp.lower_inclusive && dcomp.upper_inclusive) {
		auto scalar_fun = sexpr.mutable_scalar_function();
		vector<::substrait::Type> args_types;
		args_types.emplace_back(DuckToSubstraitType(dcomp.input->return_type));
		args_types.emplace_back(DuckToSubstraitType(dcomp.lower->return_type));
		args_types.emplace_back(DuckToSubstraitType(dcomp.upper->return_type));
		scalar_fun->set_function_reference(RegisterFunction("between", args_types));

		// Argument order: input, lower bound, upper bound.
		auto sarg = scalar_fun->add_arguments();
		TransformExpr(*dcomp.input, *sarg->mutable_value(), 0);
		sarg = scalar_fun->add_arguments();
		TransformExpr(*dcomp.lower, *sarg->mutable_value(), 0);
		sarg = scalar_fun->add_arguments();
		TransformExpr(*dcomp.upper, *sarg->mutable_value(), 0);
		*scalar_fun->mutable_output_type() = DuckToSubstraitType(dcomp.return_type);
		return;
	}

	// At least one bound is exclusive: emit `input <op> lower AND input <op> upper`.
	// Fills `out` with the binary comparison `fname(input, bound)`.
	auto emit_comparison = [this, &dcomp](const char *fname, Expression &bound, substrait::Expression &out) {
		auto cmp_fun = out.mutable_scalar_function();
		vector<::substrait::Type> cmp_types;
		cmp_types.emplace_back(DuckToSubstraitType(dcomp.input->return_type));
		cmp_types.emplace_back(DuckToSubstraitType(bound.return_type));
		cmp_fun->set_function_reference(RegisterFunction(fname, cmp_types));

		auto cmp_arg = cmp_fun->add_arguments();
		TransformExpr(*dcomp.input, *cmp_arg->mutable_value(), 0);
		cmp_arg = cmp_fun->add_arguments();
		TransformExpr(bound, *cmp_arg->mutable_value(), 0);
		// The comparison yields the same (boolean) type as the BETWEEN itself.
		*cmp_fun->mutable_output_type() = DuckToSubstraitType(dcomp.return_type);
	};

	auto and_fun = sexpr.mutable_scalar_function();
	vector<::substrait::Type> and_types;
	and_types.emplace_back(DuckToSubstraitType(dcomp.return_type));
	and_types.emplace_back(DuckToSubstraitType(dcomp.return_type));
	// NOTE(review): function names assumed to match the ones the comparison and
	// conjunction transforms register ("gt"/"gte"/"lt"/"lte"/"and") -- confirm.
	and_fun->set_function_reference(RegisterFunction("and", and_types));

	auto and_arg = and_fun->add_arguments();
	emit_comparison(dcomp.lower_inclusive ? "gte" : "gt", *dcomp.lower, *and_arg->mutable_value());
	and_arg = and_fun->add_arguments();
	emit_comparison(dcomp.upper_inclusive ? "lte" : "lt", *dcomp.upper, *and_arg->mutable_value());
	*and_fun->mutable_output_type() = DuckToSubstraitType(dcomp.return_type);
}
439+
417440
void DuckDBToSubstrait::TransformConjunctionExpression(Expression &dexpr, substrait::Expression &sexpr,
418441
uint64_t col_offset) {
419442
auto &dconj = dexpr.Cast<BoundConjunctionExpression>();
@@ -538,6 +561,9 @@ void DuckDBToSubstrait::TransformExpr(Expression &dexpr, substrait::Expression &
538561
case ExpressionType::COMPARE_NOT_DISTINCT_FROM:
539562
TransformComparisonExpression(dexpr, sexpr);
540563
break;
564+
case ExpressionType::COMPARE_BETWEEN:
565+
TransformBetweenExpression(dexpr, sexpr);
566+
break;
541567
case ExpressionType::CONJUNCTION_AND:
542568
case ExpressionType::CONJUNCTION_OR:
543569
TransformConjunctionExpression(dexpr, sexpr, col_offset);

test/sql/test_between.test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# name: test/sql/test_between.test
2+
# description: Test BETWEEN comparison
3+
# group: [sql]
4+
5+
require substrait
6+
7+
statement ok
8+
PRAGMA enable_verification
9+
10+
statement ok
11+
create table t as select * from range(100) as t(x)
12+
13+
statement ok
14+
CALL get_substrait('select * from t where x BETWEEN 4 AND 6');

test/sql/test_substrait_tpcds.test

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ CALL dsdgen(sf=0.01)
2020
#statement ok
2121
#CALL get_substrait('SELECT cd_gender, cd_marital_status, cd_education_status, Count(*) cnt1, cd_purchase_estimate, Count(*) cnt2, cd_credit_rating, Count(*) cnt3, cd_dep_count, Count(*) cnt4, cd_dep_employed_count, Count(*) cnt5, cd_dep_college_count, Count(*) cnt6 FROM customer c, customer_address ca, customer_demographics WHERE c.c_current_addr_sk = ca.ca_address_sk AND ca_county IN ( ''Lycoming County'', ''Sheridan County'', ''Kandiyohi County'', ''Pike County'', ''Greene County'' ) AND cd_demo_sk = c.c_current_cdemo_sk AND EXISTS (SELECT * FROM store_sales, date_dim WHERE c.c_customer_sk = ss_customer_sk AND ss_sold_date_sk = d_date_sk AND d_year = 2002 AND d_moy BETWEEN 4 AND 4 + 3) AND ( EXISTS (SELECT * FROM web_sales, date_dim WHERE c.c_customer_sk = ws_bill_customer_sk AND ws_sold_date_sk = d_date_sk AND d_year = 2002 AND d_moy BETWEEN 4 AND 4 + 3) OR EXISTS (SELECT * FROM catalog_sales, date_dim WHERE c.c_customer_sk = cs_ship_customer_sk AND cs_sold_date_sk = d_date_sk AND d_year = 2002 AND d_moy BETWEEN 4 AND 4 + 3) ) GROUP BY cd_gender, cd_marital_status, cd_education_status, cd_purchase_estimate, cd_credit_rating, cd_dep_count, cd_dep_employed_count, cd_dep_college_count ORDER BY cd_gender, cd_marital_status, cd_education_status, cd_purchase_estimate, cd_credit_rating, cd_dep_count, cd_dep_employed_count, cd_dep_college_count LIMIT 100; ')
2222

23-
#Q 3 (COMPARE_BETWEEN)
24-
#statement ok
25-
#CALL get_substrait('WITH year_total AS (SELECT c_customer_id customer_id, c_first_name customer_first_name , c_last_name customer_last_name, c_preferred_cust_flag customer_preferred_cust_flag , c_birth_country customer_birth_country, c_login customer_login, c_email_address customer_email_address, d_year dyear, Sum(ss_ext_list_price - ss_ext_discount_amt) year_total, ''s'' sale_type FROM customer, store_sales, date_dim WHERE c_customer_sk = ss_customer_sk AND ss_sold_date_sk = d_date_sk GROUP BY c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, c_login, c_email_address, d_year UNION ALL SELECT c_customer_id customer_id, c_first_name customer_first_name , c_last_name customer_last_name, c_preferred_cust_flag customer_preferred_cust_flag , c_birth_country customer_birth_country, c_login customer_login, c_email_address customer_email_address, d_year dyear, Sum(ws_ext_list_price - ws_ext_discount_amt) year_total, ''w'' sale_type FROM customer, web_sales, date_dim WHERE c_customer_sk = ws_bill_customer_sk AND ws_sold_date_sk = d_date_sk GROUP BY c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, c_login, c_email_address, d_year) SELECT t_s_secyear.customer_id, t_s_secyear.customer_first_name, t_s_secyear.customer_last_name, t_s_secyear.customer_birth_country FROM year_total t_s_firstyear, year_total t_s_secyear, year_total t_w_firstyear, year_total t_w_secyear WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id AND t_s_firstyear.customer_id = t_w_secyear.customer_id AND t_s_firstyear.customer_id = t_w_firstyear.customer_id AND t_s_firstyear.sale_type = ''s'' AND t_w_firstyear.sale_type = ''w'' AND t_s_secyear.sale_type = ''s'' AND t_w_secyear.sale_type = ''w'' AND t_s_firstyear.dyear = 2001 AND t_s_secyear.dyear = 2001 + 1 AND t_w_firstyear.dyear = 2001 AND t_w_secyear.dyear = 2001 + 1 AND t_s_firstyear.year_total > 0 AND t_w_firstyear.year_total > 0 AND CASE WHEN t_w_firstyear.year_total > 0 THEN 
t_w_secyear.year_total / t_w_firstyear.year_total ELSE 0.0 END > CASE WHEN t_s_firstyear.year_total > 0 THEN t_s_secyear.year_total / t_s_firstyear.year_total ELSE 0.0 END ORDER BY t_s_secyear.customer_id, t_s_secyear.customer_first_name, t_s_secyear.customer_last_name, t_s_secyear.customer_birth_country LIMIT 100; ')
23+
#Q 3
24+
statement ok
25+
CALL get_substrait('WITH year_total AS (SELECT c_customer_id customer_id, c_first_name customer_first_name , c_last_name customer_last_name, c_preferred_cust_flag customer_preferred_cust_flag , c_birth_country customer_birth_country, c_login customer_login, c_email_address customer_email_address, d_year dyear, Sum(ss_ext_list_price - ss_ext_discount_amt) year_total, ''s'' sale_type FROM customer, store_sales, date_dim WHERE c_customer_sk = ss_customer_sk AND ss_sold_date_sk = d_date_sk GROUP BY c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, c_login, c_email_address, d_year UNION ALL SELECT c_customer_id customer_id, c_first_name customer_first_name , c_last_name customer_last_name, c_preferred_cust_flag customer_preferred_cust_flag , c_birth_country customer_birth_country, c_login customer_login, c_email_address customer_email_address, d_year dyear, Sum(ws_ext_list_price - ws_ext_discount_amt) year_total, ''w'' sale_type FROM customer, web_sales, date_dim WHERE c_customer_sk = ws_bill_customer_sk AND ws_sold_date_sk = d_date_sk GROUP BY c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, c_login, c_email_address, d_year) SELECT t_s_secyear.customer_id, t_s_secyear.customer_first_name, t_s_secyear.customer_last_name, t_s_secyear.customer_birth_country FROM year_total t_s_firstyear, year_total t_s_secyear, year_total t_w_firstyear, year_total t_w_secyear WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id AND t_s_firstyear.customer_id = t_w_secyear.customer_id AND t_s_firstyear.customer_id = t_w_firstyear.customer_id AND t_s_firstyear.sale_type = ''s'' AND t_w_firstyear.sale_type = ''w'' AND t_s_secyear.sale_type = ''s'' AND t_w_secyear.sale_type = ''w'' AND t_s_firstyear.dyear = 2001 AND t_s_secyear.dyear = 2001 + 1 AND t_w_firstyear.dyear = 2001 AND t_w_secyear.dyear = 2001 + 1 AND t_s_firstyear.year_total > 0 AND t_w_firstyear.year_total > 0 AND CASE WHEN t_w_firstyear.year_total > 0 THEN 
t_w_secyear.year_total / t_w_firstyear.year_total ELSE 0.0 END > CASE WHEN t_s_firstyear.year_total > 0 THEN t_s_secyear.year_total / t_s_firstyear.year_total ELSE 0.0 END ORDER BY t_s_secyear.customer_id, t_s_secyear.customer_first_name, t_s_secyear.customer_last_name, t_s_secyear.customer_birth_country LIMIT 100; ')
2626

2727
#Q 4 (WINDOW)
2828
#statement ok

0 commit comments

Comments
 (0)