Skip to content

Commit 3d6f946

Browse files
authored
fix: update with udf report error (#18397)
* fix: update with udf report error
* try fix
* add defensive check
* fix
* add more test
1 parent d577711 commit 3d6f946

File tree

7 files changed

+110
-2
lines changed

7 files changed

+110
-2
lines changed

src/query/expression/src/converts/arrow/to.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use arrow_schema::Schema;
2626
use arrow_schema::TimeUnit;
2727
use databend_common_column::bitmap::Bitmap;
2828
use databend_common_column::buffer::buffer_to_array_data;
29+
use databend_common_exception::ErrorCode;
2930
use databend_common_exception::Result;
3031

3132
use super::ARROW_EXT_TYPE_BITMAP;
@@ -230,6 +231,14 @@ impl DataBlock {
230231
}
231232

232233
pub fn to_record_batch(self, table_schema: &TableSchema) -> Result<RecordBatch> {
234+
if self.columns().len() != table_schema.num_fields() {
235+
return Err(ErrorCode::Internal(format!(
236+
"The number of columns in the data block does not match the number of fields in the table schema, block_columns: {}, table_schema_fields: {}",
237+
self.columns().len(),
238+
table_schema.num_fields()
239+
)));
240+
}
241+
233242
if table_schema.num_fields() == 0 {
234243
return Ok(RecordBatch::try_new_with_options(
235244
Arc::new(Schema::empty()),

src/query/service/src/pipelines/builders/builder_column_mutation.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ impl PipelineBuilder {
3838
column_mutation.field_id_to_schema_index.clone(),
3939
column_mutation.input_num_columns,
4040
column_mutation.has_filter_column,
41+
column_mutation.udf_col_num,
4142
)?;
4243
}
4344

@@ -79,6 +80,7 @@ impl PipelineBuilder {
7980
mut field_id_to_schema_index: HashMap<usize, usize>,
8081
num_input_columns: usize,
8182
has_filter_column: bool,
83+
udf_col_num: usize,
8284
) -> Result<()> {
8385
let mut block_operators = Vec::new();
8486
let mut next_column_offset = num_input_columns;
@@ -129,7 +131,7 @@ impl PipelineBuilder {
129131
}
130132

131133
// Keep the original order of the columns.
132-
let num_output_columns = num_input_columns - has_filter_column as usize;
134+
let num_output_columns = num_input_columns - has_filter_column as usize - udf_col_num;
133135
let mut projection = Vec::with_capacity(num_output_columns);
134136
for idx in 0..num_output_columns {
135137
if let Some(index) = schema_offset_to_new_offset.get(&idx) {

src/query/sql/src/executor/physical_plans/physical_column_mutation.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@ pub struct ColumnMutation {
3333
pub input_num_columns: usize,
3434
pub has_filter_column: bool,
3535
pub table_meta_timestamps: TableMetaTimestamps,
36+
pub udf_col_num: usize,
3637
}

src/query/sql/src/executor/physical_plans/physical_mutation.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::collections::BTreeSet;
1516
use std::collections::HashMap;
1617
use std::sync::Arc;
1718

@@ -100,7 +101,7 @@ impl PhysicalPlanBuilder {
100101
&mut self,
101102
s_expr: &SExpr,
102103
mutation: &crate::plans::Mutation,
103-
required: ColumnSet,
104+
mut required: ColumnSet,
104105
) -> Result<PhysicalPlan> {
105106
let crate::plans::Mutation {
106107
bind_context,
@@ -122,9 +123,38 @@ impl PhysicalPlanBuilder {
122123
can_try_update_column_only,
123124
no_effect,
124125
truncate_table,
126+
direct_filter,
125127
..
126128
} = mutation;
127129

130+
let mut maybe_udfs = BTreeSet::new();
131+
for matched_evaluator in matched_evaluators {
132+
if let Some(condition) = &matched_evaluator.condition {
133+
maybe_udfs.extend(condition.used_columns());
134+
}
135+
if let Some(update_list) = &matched_evaluator.update {
136+
for update_scalar in update_list.values() {
137+
maybe_udfs.extend(update_scalar.used_columns());
138+
}
139+
}
140+
}
141+
for unmatched_evaluator in unmatched_evaluators {
142+
if let Some(condition) = &unmatched_evaluator.condition {
143+
maybe_udfs.extend(condition.used_columns());
144+
}
145+
for value in &unmatched_evaluator.values {
146+
maybe_udfs.extend(value.used_columns());
147+
}
148+
}
149+
for filter_value in direct_filter {
150+
maybe_udfs.extend(filter_value.used_columns());
151+
}
152+
153+
let udf_ids = s_expr.get_udfs_col_ids()?;
154+
let required_udf_ids: BTreeSet<_> = maybe_udfs.intersection(&udf_ids).collect();
155+
let udf_col_num = required_udf_ids.len();
156+
required.extend(required_udf_ids);
157+
128158
let mut plan = self.build(s_expr.child(0)?, required).await?;
129159
if *no_effect {
130160
return Ok(plan);
@@ -220,6 +250,7 @@ impl PhysicalPlanBuilder {
220250
input_num_columns: mutation_input_schema.fields().len(),
221251
has_filter_column: predicate_column_index.is_some(),
222252
table_meta_timestamps: mutation_build_info.table_meta_timestamps,
253+
udf_col_num,
223254
});
224255

225256
if *distributed {

src/query/sql/src/planner/optimizer/ir/expr/s_expr.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::collections::BTreeSet;
1516
use std::collections::HashSet;
1617
use std::sync::Arc;
1718
use std::sync::Mutex;
@@ -224,6 +225,21 @@ impl SExpr {
224225
Ok(udfs)
225226
}
226227

228+
#[recursive::recursive]
229+
pub fn get_udfs_col_ids(&self) -> Result<BTreeSet<IndexType>> {
230+
let mut udf_ids = BTreeSet::new();
231+
if let RelOperator::Udf(udf) = self.plan.as_ref() {
232+
for item in udf.items.iter() {
233+
udf_ids.insert(item.index);
234+
}
235+
}
236+
for child in &self.children {
237+
let udfs = child.get_udfs_col_ids()?;
238+
udf_ids.extend(udfs);
239+
}
240+
Ok(udf_ids)
241+
}
242+
227243
// Add column index to Scan nodes that match the given table index
228244
pub fn add_column_index_to_scans(
229245
&self,

tests/sqllogictests/suites/udf_server/udf_server_test.test

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,51 @@ select * from _tmp_table order by field1;
564564
4
565565
5
566566

567+
statement ok
568+
CREATE OR REPLACE TABLE test_update_udf(url STRING, length INT64);
569+
570+
statement ok
571+
INSERT INTO test_update_udf (url) VALUES('databend.com'),('databend.cn');
572+
573+
statement ok
574+
UPDATE test_update_udf SET length = url_len(url);
575+
576+
query TI
577+
SELECT * FROM test_update_udf;
578+
----
579+
databend.com 12
580+
databend.cn 11
581+
582+
583+
statement ok
584+
CREATE OR REPLACE TABLE test_update_udf_1(url STRING, a INT64,b INT64,c INT64);
585+
586+
statement ok
587+
CREATE OR REPLACE FUNCTION url_len_mul_100 (VARCHAR) RETURNS BIGINT LANGUAGE python IMMUTABLE HANDLER = 'url_len_mul_100' ADDRESS = 'http://0.0.0.0:8815';
588+
589+
statement ok
590+
INSERT INTO test_update_udf_1 (url) VALUES('databend.com'),('databend.cn');
591+
592+
statement ok
593+
UPDATE test_update_udf_1 SET a = url_len(url),b = url_len_mul_100(url), c = length(url) + 123;
594+
595+
query TIII
596+
SELECT * FROM test_update_udf_1;
597+
----
598+
databend.com 12 1200 135
599+
databend.cn 11 1100 134
600+
601+
statement ok
602+
UPDATE test_update_udf_1 SET b = url_len(url),c = url_len_mul_100(url), a = length(url) + 123;
603+
604+
query TIII
605+
SELECT * FROM test_update_udf_1;
606+
----
607+
databend.com 135 12 1200
608+
databend.cn 134 11 1100
609+
610+
611+
567612
query I
568613
SELECT url_len('databend.com');
569614
----

tests/udf/udf_server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ def json_access(data: Any, key: str) -> Any:
190190
def url_len(key: str) -> int:
191191
return len(key)
192192

193+
@udf(input_types=["VARCHAR"], result_type="BIGINT")
194+
def url_len_mul_100(key: str) -> int:
195+
return len(key) * 100
193196

194197
@udf(input_types=["ARRAY(VARIANT)"], result_type="VARIANT")
195198
def json_concat(list: List[Any]) -> Any:
@@ -449,6 +452,7 @@ def embedding_4(s: str):
449452
udf_server.add_function(wait)
450453
udf_server.add_function(wait_concurrent)
451454
udf_server.add_function(url_len)
455+
udf_server.add_function(url_len_mul_100)
452456
udf_server.add_function(check_headers)
453457
udf_server.add_function(embedding_4)
454458

0 commit comments

Comments (0)