diff --git a/src/query/expression/src/converts/arrow/to.rs b/src/query/expression/src/converts/arrow/to.rs
index 5316a3d65ce77..c6d2b609bf45a 100644
--- a/src/query/expression/src/converts/arrow/to.rs
+++ b/src/query/expression/src/converts/arrow/to.rs
@@ -26,6 +26,7 @@ use arrow_schema::Schema;
 use arrow_schema::TimeUnit;
 use databend_common_column::bitmap::Bitmap;
 use databend_common_column::buffer::buffer_to_array_data;
+use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
 
 use super::ARROW_EXT_TYPE_BITMAP;
@@ -230,6 +231,14 @@ impl DataBlock {
     }
 
     pub fn to_record_batch(self, table_schema: &TableSchema) -> Result<RecordBatch> {
+        if self.columns().len() != table_schema.num_fields() {
+            return Err(ErrorCode::Internal(format!(
+                "The number of columns in the data block does not match the number of fields in the table schema, block_columns: {}, table_schema_fields: {}",
+                self.columns().len(),
+                table_schema.num_fields()
+            )));
+        }
+
         if table_schema.num_fields() == 0 {
             return Ok(RecordBatch::try_new_with_options(
                 Arc::new(Schema::empty()),
diff --git a/src/query/service/src/pipelines/builders/builder_column_mutation.rs b/src/query/service/src/pipelines/builders/builder_column_mutation.rs
index 071aa04b199d4..f3c5e24d9074b 100644
--- a/src/query/service/src/pipelines/builders/builder_column_mutation.rs
+++ b/src/query/service/src/pipelines/builders/builder_column_mutation.rs
@@ -38,6 +38,7 @@ impl PipelineBuilder {
                 column_mutation.field_id_to_schema_index.clone(),
                 column_mutation.input_num_columns,
                 column_mutation.has_filter_column,
+                column_mutation.udf_col_num,
             )?;
         }
 
@@ -79,6 +80,7 @@ impl PipelineBuilder {
         mut field_id_to_schema_index: HashMap<usize, usize>,
         num_input_columns: usize,
         has_filter_column: bool,
+        udf_col_num: usize,
     ) -> Result<()> {
         let mut block_operators = Vec::new();
         let mut next_column_offset = num_input_columns;
@@ -129,7 +131,7 @@
         }
 
         // Keep the original order of the columns.
-        let num_output_columns = num_input_columns - has_filter_column as usize;
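+        // The filter column and any UDF result columns were appended after the
+        // original input columns; both have been consumed at this point, so
+        // they are excluded from the output projection.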
+        let num_output_columns = num_input_columns - has_filter_column as usize - udf_col_num;
         let mut projection = Vec::with_capacity(num_output_columns);
         for idx in 0..num_output_columns {
             if let Some(index) = schema_offset_to_new_offset.get(&idx) {
diff --git a/src/query/sql/src/executor/physical_plans/physical_column_mutation.rs b/src/query/sql/src/executor/physical_plans/physical_column_mutation.rs
index d3475e796ee16..b14d0f0e19c08 100644
--- a/src/query/sql/src/executor/physical_plans/physical_column_mutation.rs
+++ b/src/query/sql/src/executor/physical_plans/physical_column_mutation.rs
@@ -33,4 +33,5 @@ pub struct ColumnMutation {
     pub input_num_columns: usize,
     pub has_filter_column: bool,
     pub table_meta_timestamps: TableMetaTimestamps,
+    pub udf_col_num: usize,
 }
diff --git a/src/query/sql/src/executor/physical_plans/physical_mutation.rs b/src/query/sql/src/executor/physical_plans/physical_mutation.rs
index 9feb40c125a1d..8179ad56c2998 100644
--- a/src/query/sql/src/executor/physical_plans/physical_mutation.rs
+++ b/src/query/sql/src/executor/physical_plans/physical_mutation.rs
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::BTreeSet;
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -100,7 +101,7 @@ impl PhysicalPlanBuilder {
         &mut self,
         s_expr: &SExpr,
         mutation: &crate::plans::Mutation,
-        required: ColumnSet,
+        mut required: ColumnSet,
     ) -> Result<PhysicalPlan> {
         let crate::plans::Mutation {
             bind_context,
@@ -122,9 +123,38 @@ impl PhysicalPlanBuilder {
             can_try_update_column_only,
             no_effect,
             truncate_table,
+            direct_filter,
             ..
         } = mutation;
 
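+        // Collect every column referenced by match conditions, update
+        // expressions, insert values and direct filters, then intersect with
+        // the columns produced by UDF nodes below: those UDF columns must stay
+        // in `required` so column pruning keeps them for the mutation.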
+        let mut maybe_udfs = BTreeSet::new();
+        for matched_evaluator in matched_evaluators {
+            if let Some(condition) = &matched_evaluator.condition {
+                maybe_udfs.extend(condition.used_columns());
+            }
+            if let Some(update_list) = &matched_evaluator.update {
+                for update_scalar in update_list.values() {
+                    maybe_udfs.extend(update_scalar.used_columns());
+                }
+            }
+        }
+        for unmatched_evaluator in unmatched_evaluators {
+            if let Some(condition) = &unmatched_evaluator.condition {
+                maybe_udfs.extend(condition.used_columns());
+            }
+            for value in &unmatched_evaluator.values {
+                maybe_udfs.extend(value.used_columns());
+            }
+        }
+        for filter_value in direct_filter {
+            maybe_udfs.extend(filter_value.used_columns());
+        }
+
+        let udf_ids = s_expr.get_udfs_col_ids()?;
+        let required_udf_ids: BTreeSet<_> = maybe_udfs.intersection(&udf_ids).collect();
+        let udf_col_num = required_udf_ids.len();
+        required.extend(required_udf_ids);
+
         let mut plan = self.build(s_expr.child(0)?, required).await?;
         if *no_effect {
             return Ok(plan);
@@ -220,6 +250,7 @@ impl PhysicalPlanBuilder {
             input_num_columns: mutation_input_schema.fields().len(),
             has_filter_column: predicate_column_index.is_some(),
             table_meta_timestamps: mutation_build_info.table_meta_timestamps,
+            udf_col_num,
         });
 
         if *distributed {
diff --git a/src/query/sql/src/planner/optimizer/ir/expr/s_expr.rs b/src/query/sql/src/planner/optimizer/ir/expr/s_expr.rs
index 0a992b794f848..118372f96546c 100644
--- a/src/query/sql/src/planner/optimizer/ir/expr/s_expr.rs
+++ b/src/query/sql/src/planner/optimizer/ir/expr/s_expr.rs
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::BTreeSet;
 use std::collections::HashSet;
 use std::sync::Arc;
 use std::sync::Mutex;
@@ -224,6 +225,21 @@
         Ok(udfs)
     }
 
+    #[recursive::recursive]
+    pub fn get_udfs_col_ids(&self) -> Result<BTreeSet<IndexType>> {
+        let mut udf_ids = BTreeSet::new();
+        if let RelOperator::Udf(udf) = self.plan.as_ref() {
+            for item in udf.items.iter() {
+                udf_ids.insert(item.index);
+            }
+        }
+        for child in &self.children {
+            let udfs = child.get_udfs_col_ids()?;
+            udf_ids.extend(udfs);
+        }
+        Ok(udf_ids)
+    }
+
     // Add column index to Scan nodes that match the given table index
     pub fn add_column_index_to_scans(
         &self,
diff --git a/tests/sqllogictests/suites/udf_server/udf_server_test.test b/tests/sqllogictests/suites/udf_server/udf_server_test.test
index d9f0ad3acbda8..d5a0eea86655c 100644
--- a/tests/sqllogictests/suites/udf_server/udf_server_test.test
+++ b/tests/sqllogictests/suites/udf_server/udf_server_test.test
@@ -564,6 +564,51 @@ select * from _tmp_table order by field1;
 4
 5
 
+statement ok
+CREATE OR REPLACE TABLE test_update_udf(url STRING, length INT64);
+
+statement ok
+INSERT INTO test_update_udf (url) VALUES('databend.com'),('databend.cn');
+
+statement ok
+UPDATE test_update_udf SET length = url_len(url);
+
+query TI
+SELECT * FROM test_update_udf;
+----
+databend.com 12
+databend.cn 11
+
+
+statement ok
+CREATE OR REPLACE TABLE test_update_udf_1(url STRING, a INT64, b INT64, c INT64);
+
+statement ok
+CREATE OR REPLACE FUNCTION url_len_mul_100 (VARCHAR) RETURNS BIGINT LANGUAGE python IMMUTABLE HANDLER = 'url_len_mul_100' ADDRESS = 'http://0.0.0.0:8815';
+
+statement ok
+INSERT INTO test_update_udf_1 (url) VALUES('databend.com'),('databend.cn');
+
+statement ok
+UPDATE test_update_udf_1 SET a = url_len(url), b = url_len_mul_100(url), c = length(url) + 123;
+
+query TIII
+SELECT * FROM test_update_udf_1;
+----
+databend.com 12 1200 135
+databend.cn 11 1100 134
+
+statement ok
+UPDATE test_update_udf_1 SET b = url_len(url), c = url_len_mul_100(url), a = length(url) + 123;
+
+query TIII
+SELECT * FROM test_update_udf_1;
+----
+databend.com 135 12 1200
+databend.cn 134 11 1100
+
+
+
 query I
 SELECT url_len('databend.com');
 ----
diff --git a/tests/udf/udf_server.py b/tests/udf/udf_server.py
index f34a107961182..c13f9448b4c25 100644
--- a/tests/udf/udf_server.py
+++ b/tests/udf/udf_server.py
@@ -190,6 +190,9 @@ def json_access(data: Any, key: str) -> Any:
 def url_len(key: str) -> int:
     return len(key)
 
+@udf(input_types=["VARCHAR"], result_type="BIGINT")
+def url_len_mul_100(key: str) -> int:
+    return len(key) * 100
 
 @udf(input_types=["ARRAY(VARIANT)"], result_type="VARIANT")
 def json_concat(list: List[Any]) -> Any:
@@ -449,6 +452,7 @@ def embedding_4(s: str):
     udf_server.add_function(wait)
     udf_server.add_function(wait_concurrent)
     udf_server.add_function(url_len)
+    udf_server.add_function(url_len_mul_100)
     udf_server.add_function(check_headers)
     udf_server.add_function(embedding_4)