Skip to content

Commit d22e293

Browse files
brayanjulsalambqstommyshu
authored
Include data types in logical plans of inferred prepare statements (apache#16019)
* draft commit to rolledback changes on function naming and include prepare clause on the infer types tests * include data types in plan when it is not included in the prepare statement * fix: prepare statement error * Update datafusion/sql/src/statement.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * remove infer types from prepare statement the infer data type changes in statement will be introduced in a new PR * fix to show correct output message * include data types on logical plans of prepare statements without explicit type declaration * fix using clippy sugestions * explicitly get the data types using the placeholder id to avoid sorting * Restore the original tests too * update set data type routine to be more rust idiomatic Co-authored-by: Tommy shu <qstommyshu@gmail.com> * update set datatype routine * fix formatting in sql_integration --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> Co-authored-by: Tommy shu <qstommyshu@gmail.com>
1 parent 5669500 commit d22e293

File tree

3 files changed

+109
-8
lines changed

3 files changed

+109
-8
lines changed

datafusion/sql/src/statement.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
696696
statement,
697697
} => {
698698
// Convert parser data types to DataFusion data types
699-
let data_types: Vec<DataType> = data_types
699+
let mut data_types: Vec<DataType> = data_types
700700
.into_iter()
701701
.map(|t| self.convert_data_type(&t))
702702
.collect::<Result<_>>()?;
@@ -710,6 +710,19 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
710710
*statement,
711711
&mut planner_context,
712712
)?;
713+
714+
if data_types.is_empty() {
715+
let map_types = plan.get_parameter_types()?;
716+
let param_types: Vec<_> = (1..=map_types.len())
717+
.filter_map(|i| {
718+
let key = format!("${i}");
719+
map_types.get(&key).and_then(|opt| opt.clone())
720+
})
721+
.collect();
722+
data_types.extend(param_types.iter().cloned());
723+
planner_context.with_prepare_param_data_types(param_types);
724+
}
725+
713726
Ok(LogicalPlan::Statement(PlanStatement::Prepare(Prepare {
714727
name: ident_to_string(&name),
715728
data_types,

datafusion/sql/tests/sql_integration.rs

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4650,7 +4650,7 @@ fn test_prepare_statement_infer_types_from_join() {
46504650
assert_snapshot!(
46514651
plan,
46524652
@r#"
4653-
Prepare: "my_plan" []
4653+
Prepare: "my_plan" [Int32]
46544654
Projection: person.id, orders.order_id
46554655
Inner Join: Filter: person.id = orders.customer_id AND person.age = $1
46564656
TableScan: person
@@ -4661,6 +4661,20 @@ fn test_prepare_statement_infer_types_from_join() {
46614661
let actual_types = plan.get_parameter_types().unwrap();
46624662
let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]);
46634663
assert_eq!(actual_types, expected_types);
4664+
4665+
// replace params with values
4666+
let param_values = vec![ScalarValue::Int32(Some(10))];
4667+
let plan_with_params = plan.with_param_values(param_values).unwrap();
4668+
4669+
assert_snapshot!(
4670+
plan_with_params,
4671+
@r"
4672+
Projection: person.id, orders.order_id
4673+
Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10)
4674+
TableScan: person
4675+
TableScan: orders
4676+
"
4677+
);
46644678
}
46654679

46664680
#[test]
@@ -4701,7 +4715,7 @@ fn test_prepare_statement_infer_types_from_predicate() {
47014715
assert_snapshot!(
47024716
plan,
47034717
@r#"
4704-
Prepare: "my_plan" []
4718+
Prepare: "my_plan" [Int32]
47054719
Projection: person.id, person.age
47064720
Filter: person.age = $1
47074721
TableScan: person
@@ -4711,6 +4725,19 @@ fn test_prepare_statement_infer_types_from_predicate() {
47114725
let actual_types = plan.get_parameter_types().unwrap();
47124726
let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]);
47134727
assert_eq!(actual_types, expected_types);
4728+
4729+
// replace params with values
4730+
let param_values = vec![ScalarValue::Int32(Some(10))];
4731+
let plan_with_params = plan.with_param_values(param_values).unwrap();
4732+
4733+
assert_snapshot!(
4734+
plan_with_params,
4735+
@r"
4736+
Projection: person.id, person.age
4737+
Filter: person.age = Int32(10)
4738+
TableScan: person
4739+
"
4740+
);
47144741
}
47154742

47164743
#[test]
@@ -4756,7 +4783,7 @@ fn test_prepare_statement_infer_types_from_between_predicate() {
47564783
assert_snapshot!(
47574784
plan,
47584785
@r#"
4759-
Prepare: "my_plan" []
4786+
Prepare: "my_plan" [Int32, Int32]
47604787
Projection: person.id, person.age
47614788
Filter: person.age BETWEEN $1 AND $2
47624789
TableScan: person
@@ -4769,6 +4796,19 @@ fn test_prepare_statement_infer_types_from_between_predicate() {
47694796
("$2".to_string(), Some(DataType::Int32)),
47704797
]);
47714798
assert_eq!(actual_types, expected_types);
4799+
4800+
// replace params with values
4801+
let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))];
4802+
let plan_with_params = plan.with_param_values(param_values).unwrap();
4803+
4804+
assert_snapshot!(
4805+
plan_with_params,
4806+
@r"
4807+
Projection: person.id, person.age
4808+
Filter: person.age BETWEEN Int32(10) AND Int32(30)
4809+
TableScan: person
4810+
"
4811+
);
47724812
}
47734813

47744814
#[test]
@@ -4821,7 +4861,7 @@ fn test_prepare_statement_infer_types_subquery() {
48214861
assert_snapshot!(
48224862
plan,
48234863
@r#"
4824-
Prepare: "my_plan" []
4864+
Prepare: "my_plan" [UInt32]
48254865
Projection: person.id, person.age
48264866
Filter: person.age = (<subquery>)
48274867
Subquery:
@@ -4836,6 +4876,24 @@ fn test_prepare_statement_infer_types_subquery() {
48364876
let actual_types = plan.get_parameter_types().unwrap();
48374877
let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]);
48384878
assert_eq!(actual_types, expected_types);
4879+
4880+
// replace params with values
4881+
let param_values = vec![ScalarValue::UInt32(Some(10))];
4882+
let plan_with_params = plan.with_param_values(param_values).unwrap();
4883+
4884+
assert_snapshot!(
4885+
plan_with_params,
4886+
@r"
4887+
Projection: person.id, person.age
4888+
Filter: person.age = (<subquery>)
4889+
Subquery:
4890+
Projection: max(person.age)
4891+
Aggregate: groupBy=[[]], aggr=[[max(person.age)]]
4892+
Filter: person.id = UInt32(10)
4893+
TableScan: person
4894+
TableScan: person
4895+
"
4896+
);
48394897
}
48404898

48414899
#[test]
@@ -4883,7 +4941,7 @@ fn test_prepare_statement_update_infer() {
48834941
assert_snapshot!(
48844942
plan,
48854943
@r#"
4886-
Prepare: "my_plan" []
4944+
Prepare: "my_plan" [Int32, UInt32]
48874945
Dml: op=[Update] table=[person]
48884946
Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀
48894947
Filter: person.id = $2
@@ -4897,6 +4955,20 @@ fn test_prepare_statement_update_infer() {
48974955
("$2".to_string(), Some(DataType::UInt32)),
48984956
]);
48994957
assert_eq!(actual_types, expected_types);
4958+
4959+
// replace params with values
4960+
let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))];
4961+
let plan_with_params = plan.with_param_values(param_values).unwrap();
4962+
4963+
assert_snapshot!(
4964+
plan_with_params,
4965+
@r"
4966+
Dml: op=[Update] table=[person]
4967+
Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀
4968+
Filter: person.id = UInt32(1)
4969+
TableScan: person
4970+
"
4971+
);
49004972
}
49014973

49024974
#[test]
@@ -4944,7 +5016,7 @@ fn test_prepare_statement_insert_infer() {
49445016
assert_snapshot!(
49455017
plan,
49465018
@r#"
4947-
Prepare: "my_plan" []
5019+
Prepare: "my_plan" [UInt32, Utf8, Utf8]
49485020
Dml: op=[Insert Into] table=[person]
49495021
Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀
49505022
Values: ($1, $2, $3)
@@ -4958,6 +5030,22 @@ fn test_prepare_statement_insert_infer() {
49585030
("$3".to_string(), Some(DataType::Utf8)),
49595031
]);
49605032
assert_eq!(actual_types, expected_types);
5033+
5034+
// replace params with values
5035+
let param_values = vec![
5036+
ScalarValue::UInt32(Some(1)),
5037+
ScalarValue::from("Alan"),
5038+
ScalarValue::from("Turing"),
5039+
];
5040+
let plan_with_params = plan.with_param_values(param_values).unwrap();
5041+
assert_snapshot!(
5042+
plan_with_params,
5043+
@r#"
5044+
Dml: op=[Insert Into] table=[person]
5045+
Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀
5046+
Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3)
5047+
"#
5048+
);
49615049
}
49625050

49635051
#[test]

datafusion/sqllogictest/test_files/prepare.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ DEALLOCATE my_plan
9292
statement ok
9393
PREPARE my_plan AS SELECT * FROM person WHERE id < $1;
9494

95-
statement error No value found for placeholder with id \$1
95+
statement error Prepared statement 'my_plan' expects 1 parameters, but 0 provided
9696
EXECUTE my_plan
9797

9898
statement ok

0 commit comments

Comments
 (0)