Skip to content

Commit 639638f

Browse files
authored
fix(query): fix check join condition with aggregate and window functions (#13879)
1 parent e5db3a9 commit 639638f

File tree

8 files changed

+75
-41
lines changed

8 files changed

+75
-41
lines changed

src/query/expression/src/evaluator.rs

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -922,14 +922,15 @@ impl<'a> Evaluator<'a> {
922922
let result = evaluator.run(&expr)?;
923923
let result_col = result.convert_to_full_column(expr.data_type(), c.len());
924924

925-
if func_name == "array_filter" {
925+
let val = if func_name == "array_filter" {
926926
let result_col = result_col.remove_nullable();
927927
let bitmap = result_col.as_boolean().unwrap();
928928
let filtered_inner_col = c.filter(bitmap);
929-
Ok(Value::Scalar(Scalar::Array(filtered_inner_col)))
929+
Value::Scalar(Scalar::Array(filtered_inner_col))
930930
} else {
931-
Ok(Value::Scalar(Scalar::Array(result_col)))
932-
}
931+
Value::Scalar(Scalar::Array(result_col))
932+
};
933+
Ok(val)
933934
}
934935
_ => unreachable!(),
935936
},
@@ -959,7 +960,7 @@ impl<'a> Evaluator<'a> {
959960
let result = evaluator.run(&expr)?;
960961
let result_col = result.convert_to_full_column(expr.data_type(), inner_col.len());
961962

962-
let col = if func_name == "array_filter" {
963+
let array_col = if func_name == "array_filter" {
963964
let result_col = result_col.remove_nullable();
964965
let bitmap = result_col.as_boolean().unwrap();
965966
let filtered_inner_col = inner_col.filter(bitmap);
@@ -975,33 +976,22 @@ impl<'a> Evaluator<'a> {
975976
filtered_offsets.push(new_offset);
976977
}
977978

978-
let array_col = Column::Array(Box::new(ArrayColumn {
979+
Column::Array(Box::new(ArrayColumn {
979980
values: filtered_inner_col,
980981
offsets: filtered_offsets.into(),
981-
}));
982-
match validity {
983-
Some(validity) => {
984-
Value::Column(Column::Nullable(Box::new(NullableColumn {
985-
column: array_col,
986-
validity,
987-
})))
988-
}
989-
None => Value::Column(array_col),
990-
}
982+
}))
991983
} else {
992-
let array_col = Column::Array(Box::new(ArrayColumn {
984+
Column::Array(Box::new(ArrayColumn {
993985
values: result_col,
994986
offsets,
995-
}));
996-
match validity {
997-
Some(validity) => {
998-
Value::Column(Column::Nullable(Box::new(NullableColumn {
999-
column: array_col,
1000-
validity,
1001-
})))
1002-
}
1003-
None => Value::Column(array_col),
1004-
}
987+
}))
988+
};
989+
let col = match validity {
990+
Some(validity) => Value::Column(Column::Nullable(Box::new(NullableColumn {
991+
column: array_col,
992+
validity,
993+
}))),
994+
None => Value::Column(array_col),
1005995
};
1006996
Ok(col)
1007997
}

src/query/expression/src/function.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -391,12 +391,7 @@ impl FunctionRegistry {
391391
pub fn get_property(&self, func_name: &str) -> Option<FunctionProperty> {
392392
let func_name = func_name.to_lowercase();
393393
if self.contains(&func_name) {
394-
Some(
395-
self.properties
396-
.get(&func_name.to_lowercase())
397-
.cloned()
398-
.unwrap_or_default(),
399-
)
394+
Some(self.properties.get(&func_name).cloned().unwrap_or_default())
400395
} else {
401396
None
402397
}

src/query/sql/src/planner/binder/join.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use common_exception::Result;
2727
use common_exception::Span;
2828
use indexmap::IndexMap;
2929

30+
use super::Finder;
3031
use crate::binder::CteInfo;
3132
use crate::binder::JoinPredicate;
3233
use crate::binder::Visibility;
@@ -44,6 +45,7 @@ use crate::plans::Filter;
4445
use crate::plans::Join;
4546
use crate::plans::JoinType;
4647
use crate::plans::ScalarExpr;
48+
use crate::plans::Visitor;
4749
use crate::BindContext;
4850
use crate::IndexType;
4951
use crate::MetadataRef;
@@ -521,6 +523,36 @@ impl<'a> JoinConditionResolver<'a> {
521523
);
522524
}
523525
}
526+
527+
self.check_join_allowed_scalar_expr(left_join_conditions)
528+
.await?;
529+
self.check_join_allowed_scalar_expr(right_join_conditions)
530+
.await?;
531+
self.check_join_allowed_scalar_expr(non_equi_conditions)
532+
.await?;
533+
self.check_join_allowed_scalar_expr(other_join_conditions)
534+
.await?;
535+
536+
Ok(())
537+
}
538+
539+
async fn check_join_allowed_scalar_expr(&mut self, scalars: &Vec<ScalarExpr>) -> Result<()> {
540+
let f = |scalar: &ScalarExpr| {
541+
matches!(
542+
scalar,
543+
ScalarExpr::WindowFunction(_) | ScalarExpr::AggregateFunction(_)
544+
)
545+
};
546+
for scalar in scalars {
547+
let mut finder = Finder::new(&f);
548+
finder.visit(scalar)?;
549+
if !finder.scalars().is_empty() {
550+
return Err(ErrorCode::SemanticError(
551+
"Join condition can't contain aggregate or window functions".to_string(),
552+
)
553+
.set_span(scalar.span()));
554+
}
555+
}
524556
Ok(())
525557
}
526558

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,7 @@ impl<'a> TypeChecker<'a> {
760760
&& !GENERAL_WINDOW_FUNCTIONS.contains(&func_name)
761761
{
762762
return Err(ErrorCode::SemanticError(
763-
"only general and aggregate functions allowed in window syntax",
763+
"only window and aggregate functions allowed in window syntax",
764764
)
765765
.set_span(*span));
766766
}
@@ -821,7 +821,7 @@ impl<'a> TypeChecker<'a> {
821821
// general window function
822822
if window.is_none() {
823823
return Err(ErrorCode::SemanticError(format!(
824-
"window function {name} can only be used in window clause"
824+
"window function {func_name} can only be used in window clause"
825825
)));
826826
}
827827
let func = self
@@ -882,19 +882,19 @@ impl<'a> TypeChecker<'a> {
882882
let params = lambda
883883
.params
884884
.iter()
885-
.map(|param| param.name.clone())
885+
.map(|param| param.name.to_lowercase())
886886
.collect::<Vec<_>>();
887887

888888
// TODO: support multiple params
889889
if params.len() != 1 {
890890
return Err(ErrorCode::SemanticError(format!(
891-
"incorrect number of parameters in lambda function, {name} expects 1 parameter",
891+
"incorrect number of parameters in lambda function, {func_name} expects 1 parameter",
892892
)));
893893
}
894894

895895
if args.len() != 1 {
896896
return Err(ErrorCode::SemanticError(format!(
897-
"invalid arguments for lambda function, {name} expects 1 argument"
897+
"invalid arguments for lambda function, {func_name} expects 1 argument"
898898
)));
899899
}
900900
let box (arg, arg_type) = self.resolve(args[0]).await?;

src/query/storages/system/src/table_functions_table.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ use common_exception::Result;
2020
use common_expression::types::StringType;
2121
use common_expression::utils::FromData;
2222
use common_expression::DataBlock;
23+
use common_expression::FunctionKind;
2324
use common_expression::TableDataType;
2425
use common_expression::TableField;
2526
use common_expression::TableSchemaRefExt;
27+
use common_functions::BUILTIN_FUNCTIONS;
2628
use common_meta_app::schema::TableIdent;
2729
use common_meta_app::schema::TableInfo;
2830
use common_meta_app::schema::TableMeta;
@@ -43,7 +45,16 @@ impl SyncSystemTable for TableFunctionsTable {
4345

4446
fn get_full_data(&self, ctx: Arc<dyn TableContext>) -> Result<DataBlock> {
4547
let func_names = ctx.get_default_catalog()?.list_table_functions();
46-
let names = func_names.iter().map(|s| s.as_str()).collect::<Vec<_>>();
48+
let mut names = func_names.iter().map(|s| s.as_str()).collect::<Vec<_>>();
49+
// srf functions can also used as table functions
50+
let mut srf_func_names = BUILTIN_FUNCTIONS
51+
.properties
52+
.iter()
53+
.filter(|(_, property)| property.kind == FunctionKind::SRF)
54+
.map(|(name, _)| name.as_str())
55+
.collect::<Vec<_>>();
56+
names.append(&mut srf_func_names);
57+
4758
Ok(DataBlock::new_from_columns(vec![StringType::from_data(
4859
names,
4960
)]))

src/tests/sqlsmith/src/sql_gen/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
528528
let mut expr = self.gen_expr(ty);
529529
let len = self.rng.gen_range(1..=3);
530530
for _ in 0..len {
531-
let accessor = match self.rng.gen_range(0..=3) {
531+
let accessor = match self.rng.gen_range(0..=2) {
532532
0 => MapAccessor::Bracket {
533533
key: Box::new(self.gen_expr(&DataType::Number(NumberDataType::UInt8))),
534534
},

tests/sqllogictests/suites/query/02_function/02_0061_function_array.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ select array_apply(array_apply([5, NULL, 6], x -> COALESCE(x, 0) + 1), y -> y +
221221
[16,11,17]
222222

223223
query TT
224-
select array_transform(col1, a -> a * 2), array_apply(col2, b -> upper(b)) from t
224+
select array_transform(col1, A -> a * 2), array_apply(col2, B -> upper(B)) from t
225225
----
226226
[2,4,6,6] ['X','X','Y','Z']
227227

tests/sqllogictests/suites/query/join/join.test

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,12 @@ SELECT * FROM t JOIN t1, t as t2 JOIN t1 as t3;
184184
2 3 d 2 1 c
185185
2 3 d 2 3 d
186186

187+
statement error 1065
188+
SELECT * FROM t JOIN t1 on t.id = max(t1.id)
189+
190+
statement error 1065
191+
SELECT * FROM t JOIN t1 on last(t1.id) over(partition by 10)
192+
187193
statement ok
188194
drop table t;
189195

0 commit comments

Comments
 (0)