From 02b506dbcde185b4c74773d96eb38c988746cc43 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Sat, 15 Jun 2024 16:24:16 -0400
Subject: [PATCH 1/5] Improve examples of analyzer and optimizer rules

---
 datafusion-examples/README.md                 |  10 +-
 .../{rewrite_expr.rs => analyzer_rule.rs}     |   0
 .../examples/optimizer_rule.rs                | 255 ++++++++++++++++++
 datafusion-examples/examples/pruning.rs       |   3 +
 datafusion-examples/examples/sql_planning.rs  | 255 ++++++++++++++++++
 5 files changed, 518 insertions(+), 5 deletions(-)
 rename datafusion-examples/examples/{rewrite_expr.rs => analyzer_rule.rs} (100%)
 create mode 100644 datafusion-examples/examples/optimizer_rule.rs
 create mode 100644 datafusion-examples/examples/sql_planning.rs

diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index c34f706adb82..7498210f36a6 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -61,20 +61,20 @@ cargo run --example csv_sql
 - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros
 - [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function
 - [`memtable.rs`](examples/memtable.rs): Create and query data in memory using SQL and `RecordBatch`es
-- ['parquet_index.rs'](examples/parquet_index.rs): Create an secondary index over several parquet files and use it to speed up queries
+- [`parquet_index.rs`](examples/parquet_index.rs): Create a secondary index over several parquet files and use it to speed up queries
 - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file
 - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files
-- ['parquet_exec_visitor.rs'](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution
+- [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution
 - [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from Datafusion `Expr` and `LogicalPlan`
-- [`pruning.rs`](examples/parquet_sql.rs): Use pruning to rule out files based on statistics
+- [`pruning.rs`](examples/pruning.rs): Use a custom catalog and a PruningPredicate to prune files with a predicate and statistics
 - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3
 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files via HTTP
 - [`regexp.rs`](examples/regexp.rs): Examples of using regular expression functions
-- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass
 - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
 - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF)
 - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)
-- [`sql_dialect.rs`](examples/sql_dialect.rs): Example of implementing a custom SQL dialect on top of `DFParser`
+- [`sql_dialect.rs`](examples/sql_dialect.rs): Implementing a custom SQL dialect on top of `DFParser`
+- [`sql_planning.rs`](examples/rewrite_expr.rs): Create LogicalPlans (only) from sql strings
 - [`to_char.rs`](examples/to_char.rs): Examples of using the to_char function
 - [`to_timestamp.rs`](examples/to_timestamp.rs): Examples of using to_timestamp functions

diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/analyzer_rule.rs
similarity index 100%
rename from datafusion-examples/examples/rewrite_expr.rs
rename to datafusion-examples/examples/analyzer_rule.rs
diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs
new file mode 100644
index 000000000000..d8965888eab6
--- /dev/null
+++ b/datafusion-examples/examples/optimizer_rule.rs
@@ -0,0 +1,255 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
+use datafusion_common::{plan_err, Result, ScalarValue};
+use datafusion_expr::{
+    AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF,
+};
+use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule};
+use datafusion_optimizer::optimizer::Optimizer;
+use datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule};
+use datafusion_sql::planner::{ContextProvider, SqlToRel};
+use datafusion_sql::sqlparser::dialect::PostgreSqlDialect;
+use datafusion_sql::sqlparser::parser::Parser;
+use datafusion_sql::TableReference;
+use std::any::Any;
+use std::sync::Arc;
+
+pub fn main() -> Result<()> {
+    // produce a logical plan using the datafusion-sql crate
+    let dialect = PostgreSqlDialect {};
+    let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32";
+    let statements = Parser::parse_sql(&dialect, sql)?;
+
+    // produce a logical plan using the datafusion-sql crate
+    let context_provider = MyContextProvider::default();
+    let sql_to_rel = SqlToRel::new(&context_provider);
+    let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?;
+    println!(
+        "Unoptimized Logical Plan:\n\n{}\n",
+        logical_plan.display_indent()
+    );
+
+    // run the analyzer with our custom rule
+    let config = OptimizerContext::default().with_skip_failing_rules(false);
+    let analyzer = Analyzer::with_rules(vec![Arc::new(MyAnalyzerRule {})]);
+    let analyzed_plan =
+        analyzer.execute_and_check(logical_plan, config.options(), |_, _| {})?;
+    println!(
+        "Analyzed Logical Plan:\n\n{}\n",
+        analyzed_plan.display_indent()
+    );
+
+    // then run the optimizer with our custom rule
+    let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]);
+    let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?;
+    println!(
+        "Optimized Logical Plan:\n\n{}\n",
+        optimized_plan.display_indent()
+    );
+
+    Ok(())
+}
+
+fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) {
+    println!(
+        "After applying rule '{}':\n{}\n",
+        rule.name(),
+        plan.display_indent()
+    )
+}
+
+/// An example analyzer rule that changes Int64 literals to UInt64
+struct MyAnalyzerRule {}
+
+impl AnalyzerRule for MyAnalyzerRule {
+    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
+        Self::analyze_plan(plan)
+    }
+
+    fn name(&self) -> &str {
+        "my_analyzer_rule"
+    }
+}
+
+impl MyAnalyzerRule {
+    fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
+        plan.transform(|plan| {
+            Ok(match plan {
+                LogicalPlan::Filter(filter) => {
+                    let predicate = Self::analyze_expr(filter.predicate.clone())?;
+                    Transformed::yes(LogicalPlan::Filter(Filter::try_new(
+                        predicate,
+                        filter.input,
+                    )?))
+                }
+                _ => Transformed::no(plan),
+            })
+        })
+        .data()
+    }
+
+    fn analyze_expr(expr: Expr) -> Result<Expr> {
+        expr.transform(|expr| {
+            // closure is invoked for all sub expressions
+            Ok(match expr {
+                Expr::Literal(ScalarValue::Int64(i)) => {
+                    // transform to UInt64
+                    Transformed::yes(Expr::Literal(ScalarValue::UInt64(
+                        i.map(|i| i as u64),
+                    )))
+                }
+                _ => Transformed::no(expr),
+            })
+        })
+        .data()
+    }
+}
+
+/// An example optimizer rule that rewrites BETWEEN expressions to binary comparison expressions
+struct MyOptimizerRule {}
+
+impl OptimizerRule for MyOptimizerRule {
+    fn name(&self) -> &str {
+        "my_optimizer_rule"
+    }
+
+    fn try_optimize(
+        &self,
+        plan: &LogicalPlan,
+        config: &dyn OptimizerConfig,
+    ) -> Result<Option<LogicalPlan>> {
+        // recurse down and optimize children first
+        let optimized_plan = utils::optimize_children(self, plan, config)?;
+        match optimized_plan {
+            Some(LogicalPlan::Filter(filter)) => {
+                let predicate = my_rewrite(filter.predicate.clone())?;
+                Ok(Some(LogicalPlan::Filter(Filter::try_new(
+                    predicate,
+                    filter.input,
+                )?)))
+            }
+            Some(optimized_plan) => Ok(Some(optimized_plan)),
+            None => match plan {
+                LogicalPlan::Filter(filter) => {
+                    let predicate = my_rewrite(filter.predicate.clone())?;
+                    Ok(Some(LogicalPlan::Filter(Filter::try_new(
+                        predicate,
+                        filter.input.clone(),
+                    )?)))
+                }
+                _ => Ok(None),
+            },
+        }
+    }
+}
+
+/// use `Expr::transform` to modify the expression tree.
+fn my_rewrite(expr: Expr) -> Result<Expr> {
+    expr.transform(|expr| {
+        // closure is invoked for all sub expressions
+        Ok(match expr {
+            Expr::Between(Between {
+                expr,
+                negated,
+                low,
+                high,
+            }) => {
+                // unbox
+                let expr: Expr = *expr;
+                let low: Expr = *low;
+                let high: Expr = *high;
+                if negated {
+                    Transformed::yes(expr.clone().lt(low).or(expr.gt(high)))
+                } else {
+                    Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high)))
+                }
+            }
+            _ => Transformed::no(expr),
+        })
+    })
+    .data()
+}
+
+#[derive(Default)]
+struct MyContextProvider {
+    options: ConfigOptions,
+}
+
+impl ContextProvider for MyContextProvider {
+    fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
+        if name.table() == "person" {
+            Ok(Arc::new(MyTableSource {
+                schema: Arc::new(Schema::new(vec![
+                    Field::new("name", DataType::Utf8, false),
+                    Field::new("age", DataType::UInt8, false),
+                ])),
+            }))
+        } else {
+            plan_err!("table not found")
+        }
+    }
+
+    fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
+        None
+    }
+
+    fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
+        None
+    }
+
+    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
+        None
+    }
+
+    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
+        None
+    }
+
+    fn options(&self) -> &ConfigOptions {
+        &self.options
+    }
+
+    fn udf_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
+    fn udaf_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
+    fn udwf_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+}
+
+struct MyTableSource {
+    schema: SchemaRef,
+}
+
+impl TableSource for MyTableSource {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/pruning.rs
index 3fa35049a8da..9b5b87d6e0d2 100644
--- a/datafusion-examples/examples/pruning.rs
+++ b/datafusion-examples/examples/pruning.rs
@@ -33,6 +33,9 @@ use std::sync::Arc;
 /// quickly eliminate entire files / partitions / row groups of data from
 /// consideration using statistical information from a catalog or other
 /// metadata.
+///
+/// This example uses a user defined catalog to supply information. See `parquet_index.rs` for
+/// an example that extracts the necessary information from Parquet metadata.
 #[tokio::main]
 async fn main() {
     // In this example, we'll use the PruningPredicate to determine if
diff --git a/datafusion-examples/examples/sql_planning.rs b/datafusion-examples/examples/sql_planning.rs
new file mode 100644
index 000000000000..d8965888eab6
--- /dev/null
+++ b/datafusion-examples/examples/sql_planning.rs
@@ -0,0 +1,255 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{plan_err, Result, ScalarValue}; +use datafusion_expr::{ + AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, +}; +use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; +use datafusion_optimizer::optimizer::Optimizer; +use datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule}; +use datafusion_sql::planner::{ContextProvider, SqlToRel}; +use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; +use datafusion_sql::sqlparser::parser::Parser; +use datafusion_sql::TableReference; +use std::any::Any; +use std::sync::Arc; + +pub fn main() -> Result<()> { + // produce a logical plan using the datafusion-sql crate + let dialect = PostgreSqlDialect {}; + let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32"; + let statements = Parser::parse_sql(&dialect, sql)?; + + // produce a logical plan using the datafusion-sql crate + let context_provider = MyContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context_provider); + let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?; + println!( + "Unoptimized Logical Plan:\n\n{}\n", + logical_plan.display_indent() + ); + + // run the analyzer with our custom rule + let config = OptimizerContext::default().with_skip_failing_rules(false); + let analyzer = Analyzer::with_rules(vec![Arc::new(MyAnalyzerRule {})]); + let analyzed_plan = + analyzer.execute_and_check(logical_plan, config.options(), |_, _| {})?; + println!( + "Analyzed Logical Plan:\n\n{}\n", + analyzed_plan.display_indent() + ); + + // then run the optimizer with our custom rule + let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); + let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?; + println!( + "Optimized Logical Plan:\n\n{}\n", + optimized_plan.display_indent() + ); + + Ok(()) +} + +fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { + println!( + "After applying rule '{}':\n{}\n", + rule.name(), + plan.display_indent() + ) +} + +/// An example analyzer rule that changes Int64 literals to UInt64 +struct MyAnalyzerRule {} + +impl AnalyzerRule for MyAnalyzerRule { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + Self::analyze_plan(plan) + } + + fn name(&self) -> &str { + "my_analyzer_rule" + } +} + +impl MyAnalyzerRule { + fn analyze_plan(plan: LogicalPlan) -> Result { + plan.transform(|plan| { + Ok(match plan { + LogicalPlan::Filter(filter) => { + let predicate = Self::analyze_expr(filter.predicate.clone())?; + Transformed::yes(LogicalPlan::Filter(Filter::try_new( + predicate, + filter.input, + )?)) + } + _ => Transformed::no(plan), + }) + }) + .data() + } + + fn analyze_expr(expr: Expr) -> Result { + expr.transform(|expr| { + // closure is invoked for all sub expressions + Ok(match expr { + Expr::Literal(ScalarValue::Int64(i)) => { + // transform to UInt64 + Transformed::yes(Expr::Literal(ScalarValue::UInt64( + i.map(|i| i as u64), + ))) + } + _ => Transformed::no(expr), + }) + }) + .data() + } +} + +/// An example optimizer rule that rewrite BETWEEN expression to binary compare expressions +struct MyOptimizerRule {} + +impl OptimizerRule for MyOptimizerRule { + fn name(&self) -> &str { + "my_optimizer_rule" + } + + fn try_optimize( + &self, + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + ) 
-> Result> { + // recurse down and optimize children first + let optimized_plan = utils::optimize_children(self, plan, config)?; + match optimized_plan { + Some(LogicalPlan::Filter(filter)) => { + let predicate = my_rewrite(filter.predicate.clone())?; + Ok(Some(LogicalPlan::Filter(Filter::try_new( + predicate, + filter.input, + )?))) + } + Some(optimized_plan) => Ok(Some(optimized_plan)), + None => match plan { + LogicalPlan::Filter(filter) => { + let predicate = my_rewrite(filter.predicate.clone())?; + Ok(Some(LogicalPlan::Filter(Filter::try_new( + predicate, + filter.input.clone(), + )?))) + } + _ => Ok(None), + }, + } + } +} + +/// use rewrite_expr to modify the expression tree. +fn my_rewrite(expr: Expr) -> Result { + expr.transform(|expr| { + // closure is invoked for all sub expressions + Ok(match expr { + Expr::Between(Between { + expr, + negated, + low, + high, + }) => { + // unbox + let expr: Expr = *expr; + let low: Expr = *low; + let high: Expr = *high; + if negated { + Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) + } else { + Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) + } + } + _ => Transformed::no(expr), + }) + }) + .data() +} + +#[derive(Default)] +struct MyContextProvider { + options: ConfigOptions, +} + +impl ContextProvider for MyContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + if name.table() == "person" { + Ok(Arc::new(MyTableSource { + schema: Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::UInt8, false), + ])), + })) + } else { + plan_err!("table not found") + } + } + + fn get_function_meta(&self, _name: &str) -> Option> { + None + } + + fn get_aggregate_meta(&self, _name: &str) -> Option> { + None + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + None + } + + fn get_window_meta(&self, _name: &str) -> Option> { + None + } + + fn options(&self) -> &ConfigOptions { + &self.options + } + + fn udf_names(&self) -> Vec { + Vec::new() + } + + fn udaf_names(&self) -> Vec { + Vec::new() + } + + fn udwf_names(&self) -> Vec { + Vec::new() + } +} + +struct MyTableSource { + schema: SchemaRef, +} + +impl TableSource for MyTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} From 4cad93906a9692776b236668afe6a7178662a7fe Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 15 Jun 2024 16:51:02 -0400 Subject: [PATCH 2/5] Work on sql planning example --- datafusion-examples/README.md | 2 +- datafusion-examples/examples/sql_frontend.rs | 171 +++++++++++++ datafusion-examples/examples/sql_planning.rs | 255 ------------------- datafusion/expr/src/table_source.rs | 19 +- 4 files changed, 183 insertions(+), 264 deletions(-) create mode 100644 datafusion-examples/examples/sql_frontend.rs delete mode 100644 datafusion-examples/examples/sql_planning.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 7498210f36a6..add0e1b89466 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -73,8 +73,8 @@ cargo run --example csv_sql - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) +- [`sql_frontend.rs`](examples/sql_frontend.rs): Create LogicalPlans (only) from sql strings 
- [`sql_dialect.rs`](examples/sql_dialect.rs): Implementing a custom SQL dialect on top of `DFParser` -- [`sql_planning.rs`](examples/rewrite_expr.rs): Create LogicalPlans (only) from sql strings - [`to_char.rs`](examples/to_char.rs): Examples of using the to_char function - [`to_timestamp.rs`](examples/to_timestamp.rs): Examples of using to_timestamp functions diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_frontend.rs new file mode 100644 index 000000000000..e5a5caafc5b2 --- /dev/null +++ b/datafusion-examples/examples/sql_frontend.rs @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{plan_err, Result}; +use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; +use datafusion_optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule}; +use datafusion_sql::planner::{ContextProvider, SqlToRel}; +use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; +use datafusion_sql::sqlparser::parser::Parser; +use datafusion_sql::TableReference; +use std::any::Any; +use std::sync::Arc; + +/// This example shows how to use DataFusion's SQL planner to parse SQL text and +/// build `LogicalPlan`s without executing them. +/// +/// For example, if you need a SQL planner and optimizer like Apache Calcite, but +/// do not want a Java runtime dependency for some reason, you can use +/// DataFusion as a SQL frontend. +pub fn main() -> Result<()> { + // Normally, users interact with DataFusion via SessionContext. However, + // using SessionContext requires depending on the full `datafusion` crate. + // + // In this example, we demonstrate how to use the lower level APIs directly, + // which only requires the `datafusion-sql` dependencies. + + // First, we parse the SQL string. 
Note that we use the DataFusion
+    // Parser, which wraps the `sqlparser-rs` SQL parser and adds DataFusion
+    // specific syntax such as `CREATE EXTERNAL TABLE`
+    let dialect = PostgreSqlDialect {};
+    let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32";
+    let statements = Parser::parse_sql(&dialect, sql)?;
+
+    // Now, use DataFusion's SQL planner, called `SqlToRel` to create a
+    // `LogicalPlan` from the parsed statement
+    //
+    // To invoke SqlToRel we must provide it with schema and function information
+    // via an object that implements the `ContextProvider` trait
+    let context_provider = MyContextProvider::default();
+    let sql_to_rel = SqlToRel::new(&context_provider);
+    let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?;
+    println!(
+        "Unoptimized Logical Plan:\n\n{}\n",
+        logical_plan.display_indent()
+    );
+
+    // Projection: person.name, person.age
+    //   Filter: person.age BETWEEN Int64(21) AND Int64(32)
+    //     TableScan: person
+
+    // The initial plan is a mechanical translation from the parsed SQL and
+    // often cannot run. In this example, `person.age` is actually a different
+    // data type (UInt8) than the values it is compared to, which are
+    // Int64. Most execution engines, including DataFusion's, will fail if you
+    // provide such a plan.
+    //
+    // To prepare it to run, we must apply type coercion to align types, and
+    // check for other semantic errors. In DataFusion this is done by a
+    // component called the Analyzer.
+    let config = OptimizerContext::default().with_skip_failing_rules(false);
+    let analyzed_plan =Analyzer::new()
+        .execute_and_check(logical_plan, config.options(), observe_analyzer)?;
+    println!(
+        "Analyzed Logical Plan:\n\n{}\n",
+        analyzed_plan.display_indent()
+    );
+
+    // Finally we must invoke the DataFusion optimizer to improve the plan's
+    // performance by applying various rewrite rules.
+    let optimized_plan = Optimizer::new().optimize(analyzed_plan, &config, observe_optimizer)?;
+    println!(
+        "Optimized Logical Plan:\n\n{}\n",
+        optimized_plan.display_indent()
+    );
+
+    Ok(())
+}
+
+// Both the optimizer and the analyzer take a callback, called an "observer"
+// that is invoked after each pass. We do not do anything with these callbacks
+// in this example
+
+fn observe_analyzer(_plan: &LogicalPlan, _rule: &dyn AnalyzerRule) {
+}
+fn observe_optimizer(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {
+}
+
+
+/// Implements the `ContextProvider` trait required to plan SQL
+#[derive(Default)]
+struct MyContextProvider {
+    options: ConfigOptions,
+}
+
+impl ContextProvider for MyContextProvider {
+    fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
+        if name.table() == "person" {
+            Ok(Arc::new(MyTableSource {
+                schema: Arc::new(Schema::new(vec![
+                    Field::new("name", DataType::Utf8, false),
+                    Field::new("age", DataType::UInt8, false),
+                ])),
+            }))
+        } else {
+            plan_err!("table not found")
+        }
+    }
+
+    fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
+        None
+    }
+
+    fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
+        None
+    }
+
+    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
+        None
+    }
+
+    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
+        None
+    }
+
+    fn options(&self) -> &ConfigOptions {
+        &self.options
+    }
+
+    fn udf_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
+    fn udaf_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
+    fn udwf_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+}
+
+/// TableSource is the part of TableProvider needed for creating a LogicalPlan.
+struct MyTableSource { + schema: SchemaRef, +} + +impl TableSource for MyTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} diff --git a/datafusion-examples/examples/sql_planning.rs b/datafusion-examples/examples/sql_planning.rs deleted file mode 100644 index d8965888eab6..000000000000 --- a/datafusion-examples/examples/sql_planning.rs +++ /dev/null @@ -1,255 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{plan_err, Result, ScalarValue}; -use datafusion_expr::{ - AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, -}; -use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; -use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule}; -use datafusion_sql::planner::{ContextProvider, SqlToRel}; -use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; -use datafusion_sql::sqlparser::parser::Parser; -use datafusion_sql::TableReference; -use std::any::Any; -use std::sync::Arc; - -pub fn main() -> Result<()> { - // produce a logical plan using the datafusion-sql crate - let dialect = PostgreSqlDialect {}; - let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32"; - let statements = Parser::parse_sql(&dialect, sql)?; - - // produce a logical plan using the datafusion-sql crate - let context_provider = MyContextProvider::default(); - let sql_to_rel = SqlToRel::new(&context_provider); - let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?; - println!( - "Unoptimized Logical Plan:\n\n{}\n", - logical_plan.display_indent() - ); - - // run the analyzer with our custom rule - let config = OptimizerContext::default().with_skip_failing_rules(false); - let analyzer = Analyzer::with_rules(vec![Arc::new(MyAnalyzerRule {})]); - let analyzed_plan = - analyzer.execute_and_check(logical_plan, config.options(), |_, _| {})?; - println!( - "Analyzed Logical Plan:\n\n{}\n", - analyzed_plan.display_indent() - ); - - // then run the optimizer with our custom rule - let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); - let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?; - println!( - "Optimized Logical Plan:\n\n{}\n", - optimized_plan.display_indent() - ); - - Ok(()) -} - -fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { - println!( - "After applying rule '{}':\n{}\n", - rule.name(), - plan.display_indent() - ) -} - -/// An example analyzer rule that changes Int64 literals to UInt64 
-struct MyAnalyzerRule {} - -impl AnalyzerRule for MyAnalyzerRule { - fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { - Self::analyze_plan(plan) - } - - fn name(&self) -> &str { - "my_analyzer_rule" - } -} - -impl MyAnalyzerRule { - fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform(|plan| { - Ok(match plan { - LogicalPlan::Filter(filter) => { - let predicate = Self::analyze_expr(filter.predicate.clone())?; - Transformed::yes(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input, - )?)) - } - _ => Transformed::no(plan), - }) - }) - .data() - } - - fn analyze_expr(expr: Expr) -> Result { - expr.transform(|expr| { - // closure is invoked for all sub expressions - Ok(match expr { - Expr::Literal(ScalarValue::Int64(i)) => { - // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( - i.map(|i| i as u64), - ))) - } - _ => Transformed::no(expr), - }) - }) - .data() - } -} - -/// An example optimizer rule that rewrite BETWEEN expression to binary compare expressions -struct MyOptimizerRule {} - -impl OptimizerRule for MyOptimizerRule { - fn name(&self) -> &str { - "my_optimizer_rule" - } - - fn try_optimize( - &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - // recurse down and optimize children first - let optimized_plan = utils::optimize_children(self, plan, config)?; - match optimized_plan { - Some(LogicalPlan::Filter(filter)) => { - let predicate = my_rewrite(filter.predicate.clone())?; - Ok(Some(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input, - )?))) - } - Some(optimized_plan) => Ok(Some(optimized_plan)), - None => match plan { - LogicalPlan::Filter(filter) => { - let predicate = my_rewrite(filter.predicate.clone())?; - Ok(Some(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input.clone(), - )?))) - } - _ => Ok(None), - }, - } - } -} - -/// use rewrite_expr to modify the expression tree. 
-fn my_rewrite(expr: Expr) -> Result { - expr.transform(|expr| { - // closure is invoked for all sub expressions - Ok(match expr { - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - // unbox - let expr: Expr = *expr; - let low: Expr = *low; - let high: Expr = *high; - if negated { - Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) - } else { - Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) - } - } - _ => Transformed::no(expr), - }) - }) - .data() -} - -#[derive(Default)] -struct MyContextProvider { - options: ConfigOptions, -} - -impl ContextProvider for MyContextProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - if name.table() == "person" { - Ok(Arc::new(MyTableSource { - schema: Arc::new(Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("age", DataType::UInt8, false), - ])), - })) - } else { - plan_err!("table not found") - } - } - - fn get_function_meta(&self, _name: &str) -> Option> { - None - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - None - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - - fn options(&self) -> &ConfigOptions { - &self.options - } - - fn udf_names(&self) -> Vec { - Vec::new() - } - - fn udaf_names(&self) -> Vec { - Vec::new() - } - - fn udwf_names(&self) -> Vec { - Vec::new() - } -} - -struct MyTableSource { - schema: SchemaRef, -} - -impl TableSource for MyTableSource { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index 72ed51f44415..2de3cc923315 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -71,14 +71,17 @@ impl std::fmt::Display for TableType { } } -/// The TableSource trait is used during logical query planning and optimizations and -/// provides access to schema information and filter push-down capabilities. This trait -/// provides a subset of the functionality of the TableProvider trait in the core -/// datafusion crate. The TableProvider trait provides additional capabilities needed for -/// physical query execution (such as the ability to perform a scan). The reason for -/// having two separate traits is to avoid having the logical plan code be dependent -/// on the DataFusion execution engine. Other projects may want to use DataFusion's -/// logical plans and have their own execution engine. +/// Access schema information and filter push-down capabilities. +/// +/// The TableSource trait is used during logical query planning and +/// optimizations and provides a subset of the functionality of the +/// `TableProvider` trait in the (core) `datafusion` crate. The `TableProvider` +/// trait provides additional capabilities needed for physical query execution +/// (such as the ability to perform a scan). +/// +/// The reason for having two separate traits is to avoid having the logical +/// plan code be dependent on the DataFusion execution engine. Some projects use +/// DataFusion's logical plans and have their own execution engine. 
pub trait TableSource: Sync + Send {
     fn as_any(&self) -> &dyn Any;
 
From a1bc4b71983c1dad665d43408980726272ff3bb6 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Sat, 15 Jun 2024 17:04:48 -0400
Subject: [PATCH 3/5] Tweak sql planning example

---
 datafusion-examples/examples/sql_frontend.rs | 67 +++++++++++++-----
 1 file changed, 50 insertions(+), 17 deletions(-)

diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_frontend.rs
index e5a5caafc5b2..595159f13951 100644
--- a/datafusion-examples/examples/sql_frontend.rs
+++ b/datafusion-examples/examples/sql_frontend.rs
@@ -18,8 +18,13 @@
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::{plan_err, Result};
-use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
-use datafusion_optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule};
+use datafusion_expr::{
+    AggregateUDF, Expr, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, TableSource,
+    WindowUDF,
+};
+use datafusion_optimizer::{
+    Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule,
+};
 use datafusion_sql::planner::{ContextProvider, SqlToRel};
 use datafusion_sql::sqlparser::dialect::PostgreSqlDialect;
 use datafusion_sql::sqlparser::parser::Parser;
@@ -44,7 +49,7 @@ pub fn main() -> Result<()> {
     // Parser, which wraps the `sqlparser-rs` SQL parser and adds DataFusion
     // specific syntax such as `CREATE EXTERNAL TABLE`
     let dialect = PostgreSqlDialect {};
-    let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32";
+    let sql = "SELECT name FROM person WHERE age BETWEEN 21 AND 32";
     let statements = Parser::parse_sql(&dialect, sql)?;
 
     // Now, use DataFusion's SQL planner, called `SqlToRel` to create a
@@ -60,7 +65,7 @@ pub fn main() -> Result<()> {
         logical_plan.display_indent()
     );
 
-    // Projection: person.name, person.age
+    // Projection: person.name
     //   Filter: person.age BETWEEN Int64(21) AND Int64(32)
     //     TableScan: person
 
@@ -74,33 +79,52 @@ pub fn main() -> Result<()> {
     // check for other semantic errors. In DataFusion this is done by a
     // component called the Analyzer.
     let config = OptimizerContext::default().with_skip_failing_rules(false);
-    let analyzed_plan =Analyzer::new()
-        .execute_and_check(logical_plan, config.options(), observe_analyzer)?;
+    let analyzed_plan = Analyzer::new().execute_and_check(
+        logical_plan,
+        config.options(),
+        observe_analyzer,
+    )?;
     println!(
         "Analyzed Logical Plan:\n\n{}\n",
         analyzed_plan.display_indent()
     );
 
+    // Projection: person.name
+    //   Filter: CAST(person.age AS Int64) BETWEEN Int64(21) AND Int64(32)
+    //     TableScan: person
+
+    // As we can see, the Analyzer added a CAST so the types are the same
+    // (Int64). However, this plan is not as efficient as it could be, as it
+    // will require casting *each row* of the input to Int64 before comparison
+    // to 21 and 32. To optimize this query's performance, it is better to cast
+    // the constants once at plan time to UInt8, the type of `person.age`.
+    //
+    // Query optimization is handled in DataFusion by a component called the
+    // Optimizer, which we now invoke
+    let optimized_plan =
+        Optimizer::new().optimize(analyzed_plan, &config, observe_optimizer)?;
     println!(
         "Optimized Logical Plan:\n\n{}\n",
         optimized_plan.display_indent()
     );
 
-    Ok(())
-}
+    // TableScan: person projection=[name], full_filters=[person.age >= UInt8(21), person.age <= UInt8(32)]
 
-// Both the optimizer and the analyzer take a callback, called an "observer"
-// that is invoked after each pass. We do not do anything with these callbacks
-// in this example
+    // The optimizer did several things to this plan:
+    // 1. Removed casts from person.age as we described above
+    // 2. Converted BETWEEN to two single-column inequalities (which are typically faster to execute)
+    // 3. Pushed the projection of `name` down to the scan (so the scan only returns that column)
+    // 4. Pushed the filter all the way down into the scan
 
-fn observe_analyzer(_plan: &LogicalPlan, _rule: &dyn AnalyzerRule) {
-}
-fn observe_optimizer(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {
+    Ok(())
 }
 
+// Note that both the optimizer and the analyzer take a callback, called an
+// "observer" that is invoked after each pass. We do not do anything with these
+// callbacks in this example
+
+fn observe_analyzer(_plan: &LogicalPlan, _rule: &dyn AnalyzerRule) {}
+fn observe_optimizer(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
 
 /// Implements the `ContextProvider` trait required to plan SQL
 #[derive(Default)]
 struct MyContextProvider {
     options: ConfigOptions,
@@ -168,4 +192,13 @@ impl TableSource for MyTableSource {
     fn schema(&self) -> SchemaRef {
         self.schema.clone()
     }
+
+    // For this example, we report to the DataFusion optimizer that
+    // this provider can apply filters during the scan
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> Result<Vec<TableProviderFilterPushDown>> {
+        Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
+    }
 }

From 8e6984472b70177dacb542ef1630232d348e3e0c Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Sat, 15 Jun 2024 17:47:17 -0400
Subject: [PATCH 4/5] Complete analyzer example

---
 datafusion-examples/README.md                 |   1 +
 datafusion-examples/examples/analyzer_rule.rs | 288 +++++-------------
 datafusion/core/src/execution/context/mod.rs  |   7 +
 datafusion/core/src/execution/mod.rs          |   3 +
 4 files changed, 91 insertions(+), 208 deletions(-)

diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index add0e1b89466..747b9c2d5f19 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -45,6 +45,7 @@ cargo run --example csv_sql
 - [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF)
 - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF)
 - [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF)
+- [`analyzer_rule.rs`](examples/analyzer_rule.rs): Use a custom AnalyzerRule to change a query's semantics
 - [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file
 - [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog
 - [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file
diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/analyzer_rule.rs
index d8965888eab6..e89f68f0302e 100644
--- a/datafusion-examples/examples/analyzer_rule.rs
+++ b/datafusion-examples/examples/analyzer_rule.rs
@@ -15,73 +15,85 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
+use datafusion::prelude::SessionContext;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_common::{plan_err, Result, ScalarValue};
-use datafusion_expr::{
-    AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF,
-};
-use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule};
-use datafusion_optimizer::optimizer::Optimizer;
-use datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule};
-use datafusion_sql::planner::{ContextProvider, SqlToRel};
-use datafusion_sql::sqlparser::dialect::PostgreSqlDialect;
-use datafusion_sql::sqlparser::parser::Parser;
-use datafusion_sql::TableReference;
-use std::any::Any;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{lit, Expr, LogicalPlan};
+use datafusion_optimizer::analyzer::AnalyzerRule;
 use std::sync::Arc;
 
-pub fn main() -> Result<()> {
-    // produce a logical plan using the datafusion-sql crate
-    let dialect = PostgreSqlDialect {};
+/// This example demonstrates how to add your own [`AnalyzerRule`]
+/// to DataFusion.
+///
+/// [`AnalyzerRule`]s transform [`LogicalPlan`]s prior to the rest of the
+/// DataFusion optimization process, and are allowed to change the plan's
+/// semantics (e.g. output types).
+///
+#[tokio::main]
+pub async fn main() -> Result<()> {
+    // DataFusion includes several built in AnalyzerRules for tasks such as type
+    // coercion. To modify the list of rules, we must use the lower level
+    // SessionState API
+    let state = SessionContext::new().state();
+    let state = state.add_analyzer_rule(Arc::new(MyAnalyzerRule {}));
+
+    // To plan and run queries with the new rule, create a SessionContext with
+    // the modified SessionState
+    let ctx = SessionContext::from(state);
+    ctx.register_batch("person", person_batch())?;
+
+    // Plan a SQL statement as normal
     let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32";
-    let statements = Parser::parse_sql(&dialect, sql)?;
-
-    // produce a logical plan using the datafusion-sql crate
-    let context_provider = MyContextProvider::default();
-    let sql_to_rel = SqlToRel::new(&context_provider);
-    let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?;
-    println!(
-        "Unoptimized Logical Plan:\n\n{}\n",
-        logical_plan.display_indent()
-    );
-
-    // run the analyzer with our custom rule
-    let config = OptimizerContext::default().with_skip_failing_rules(false);
-    let analyzer = Analyzer::with_rules(vec![Arc::new(MyAnalyzerRule {})]);
-    let analyzed_plan =
-        analyzer.execute_and_check(logical_plan, config.options(), |_, _| {})?;
-    println!(
-        "Analyzed Logical Plan:\n\n{}\n",
-        analyzed_plan.display_indent()
-    );
-
-    // then run the optimizer with our custom rule
-    let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]);
-    let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?;
-    println!(
-        "Optimized Logical Plan:\n\n{}\n",
-        optimized_plan.display_indent()
-    );
+    let plan = ctx.sql(sql).await?.into_optimized_plan()?;
 
-    Ok(())
-}
+    println!("Logical Plan:\n\n{}\n", plan.display_indent());
+
+    // We can see the effect of our rewrite on the output plan. Even though the
+    // input query was between 21 and 32, the plan is between 31 and 42
+
+    // Filter: person.age >= Int32(31) AND person.age <= Int32(42)
+    //   TableScan: person projection=[name, age]
+
+    ctx.sql(sql).await?.show().await?;
+
+    // And the output verifies the predicates have been changed
+
+    // +-------+-----+
+    // | name  | age |
+    // +-------+-----+
+    // | Oleks | 33  |
+    // +-------+-----+
 
-fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) {
-    println!(
-        "After applying rule '{}':\n{}\n",
-        rule.name(),
-        plan.display_indent()
-    )
+    Ok(())
 }
 
-/// An example analyzer rule that changes Int64 literals to UInt64
+/// An example analyzer rule that adds 10 to all Int64 literals in the plan
 struct MyAnalyzerRule {}
 
 impl AnalyzerRule for MyAnalyzerRule {
     fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
-        Self::analyze_plan(plan)
+        // use the TreeNode API to recursively walk the LogicalPlan tree
+        // and all of its children (inputs)
+        plan.transform(|plan| {
+            // This closure is called for each LogicalPlan node
+            plan.map_expressions(|expr| {
+                // This closure is called for all expressions in the current plan
+                //
+                // For example, given a plan like `SELECT a + b, 5 + 10`
+                //
+                // The closure would be called twice, once for `a + b` and once for `5 + 10`
+                self.rewrite_expr(expr)
+            })
+        })
+        // the result of calling transform is a `Transformed` structure that
+        // contains a flag signalling if any rewrite took place as well as
+        // if the recursion stopped early.
+        //
+        // This example does not need either piece of information, so simply
+        // extract the LogicalPlan "data"
+        .data()
     }
 
     fn name(&self) -> &str {
@@ -90,166 +102,26 @@ impl AnalyzerRule for MyAnalyzerRule {
 }
 
 impl MyAnalyzerRule {
-    fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
-        plan.transform(|plan| {
-            Ok(match plan {
-                LogicalPlan::Filter(filter) => {
-                    let predicate = Self::analyze_expr(filter.predicate.clone())?;
-                    Transformed::yes(LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        filter.input,
-                    )?))
-                }
-                _ => Transformed::no(plan),
-            })
-        })
-        .data()
-    }
-
-    fn analyze_expr(expr: Expr) -> Result<Expr> {
+    /// rewrites an individual expression
+    fn rewrite_expr(&self, expr: Expr) -> Result<Transformed<Expr>> {
         expr.transform(|expr| {
             // closure is invoked for all sub expressions
-            Ok(match expr {
-                Expr::Literal(ScalarValue::Int64(i)) => {
-                    // transform to UInt64
-                    Transformed::yes(Expr::Literal(ScalarValue::UInt64(
-                        i.map(|i| i as u64),
-                    )))
-                }
-                _ => Transformed::no(expr),
-            })
-        })
-        .data()
-    }
-}
-
-/// An example optimizer rule that rewrite BETWEEN expression to binary compare expressions
-struct MyOptimizerRule {}
-
-impl OptimizerRule for MyOptimizerRule {
-    fn name(&self) -> &str {
-        "my_optimizer_rule"
-    }
-
-    fn try_optimize(
-        &self,
-        plan: &LogicalPlan,
-        config: &dyn OptimizerConfig,
-    ) -> Result<Option<LogicalPlan>> {
-        // recurse down and optimize children first
-        let optimized_plan = utils::optimize_children(self, plan, config)?;
-        match optimized_plan {
-            Some(LogicalPlan::Filter(filter)) => {
-                let predicate = my_rewrite(filter.predicate.clone())?;
-                Ok(Some(LogicalPlan::Filter(Filter::try_new(
-                    predicate,
-                    filter.input,
-                )?)))
-            }
-            Some(optimized_plan) => Ok(Some(optimized_plan)),
-            None => match plan {
-                LogicalPlan::Filter(filter) => {
-                    let predicate = my_rewrite(filter.predicate.clone())?;
-                    Ok(Some(LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        filter.input.clone(),
-                    )?)))
-                }
-                _ => Ok(None),
-            },
-        }
-    }
-}
-/// use rewrite_expr to modify the expression tree.
-fn my_rewrite(expr: Expr) -> Result<Expr> {
-    expr.transform(|expr| {
-        // closure is invoked for all sub expressions
-        Ok(match expr {
-            Expr::Between(Between {
-                expr,
-                negated,
-                low,
-                high,
-            }) => {
-                // unbox
-                let expr: Expr = *expr;
-                let low: Expr = *low;
-                let high: Expr = *high;
-                if negated {
-                    Transformed::yes(expr.clone().lt(low).or(expr.gt(high)))
-                } else {
-                    Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high)))
+            // Transformed is used to transfer the "was this rewritten"
+            // information back up the stack.
+            if let Expr::Literal(ScalarValue::Int64(Some(i))) = expr {
+                Ok(Transformed::yes(lit(i + 10)))
+            } else {
+                Ok(Transformed::no(expr))
             }
-            _ => Transformed::no(expr),
         })
-    })
-    .data()
-}
-
-#[derive(Default)]
-struct MyContextProvider {
-    options: ConfigOptions,
-}
-
-impl ContextProvider for MyContextProvider {
-    fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
-        if name.table() == "person" {
-            Ok(Arc::new(MyTableSource {
-                schema: Arc::new(Schema::new(vec![
-                    Field::new("name", DataType::Utf8, false),
-                    Field::new("age", DataType::UInt8, false),
-                ])),
-            }))
-        } else {
-            plan_err!("table not found")
-        }
-    }
-
-    fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
-        None
-    }
-
-    fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
-        None
-    }
-
-    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
-        None
-    }
-
-    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
-        None
-    }
-
-    fn options(&self) -> &ConfigOptions {
-        &self.options
-    }
-
-    fn udf_names(&self) -> Vec<String> {
-        Vec::new()
-    }
-
-    fn udaf_names(&self) -> Vec<String> {
-        Vec::new()
-    }
-
-    fn udwf_names(&self) -> Vec<String> {
-        Vec::new()
-    }
-}
-
-struct MyTableSource {
-    schema: SchemaRef,
-}
-
-impl TableSource for MyTableSource {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
+/// Return a RecordBatch with made up date
+fn person_batch() -> RecordBatch {
+    let name: ArrayRef =
+        Arc::new(StringArray::from_iter_values(["Andy", "Andrew", "Oleks"]));
+    let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22, 33]));
+    RecordBatch::try_from_iter(vec![("name", name), ("age", age)]).unwrap()
 }
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index 6fa83d3d931e..5e23fda80f26 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -1323,6 +1323,13 @@ impl FunctionRegistry for SessionContext {
     }
 }
 
+/// Create a SessionContext from a SessionState
+impl From<SessionState> for SessionContext {
+    fn from(state: SessionState) -> Self {
+        Self::new_with_state(state)
+    }
+}
+
 /// Create a new task context instance from SessionContext
 impl From<&SessionContext> for TaskContext {
     fn from(session: &SessionContext) -> Self {
diff --git a/datafusion/core/src/execution/mod.rs b/datafusion/core/src/execution/mod.rs
index ac02c7317256..75357cbdc2b9 100644
--- a/datafusion/core/src/execution/mod.rs
+++ b/datafusion/core/src/execution/mod.rs
@@ -1,3 +1,6 @@
+// DataFusion includes several built in AnalyzerRules for tasks such as
+// type coercion. Specify that our custom AnalyzerRule should run after
+// all the built in rules
 // Licensed to the Apache Software Foundation (ASF) under one
 // or more contributor license agreements. See the NOTICE file
 // distributed with this work for additional information

From 4040295b879a0834d05faa574e046bbbc567053d Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Sat, 15 Jun 2024 18:24:21 -0400
Subject: [PATCH 5/5] Add optimizer rule

---
 datafusion-examples/README.md                 |   1 +
 datafusion-examples/examples/analyzer_rule.rs |   3 +-
 .../examples/optimizer_rule.rs                | 346 ++++++++----------
 3 files changed, 159 insertions(+), 191 deletions(-)

diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index 747b9c2d5f19..fc15ff662e6b 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -62,6 +62,7 @@ cargo run --example csv_sql
 - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros
 - [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function
 - [`memtable.rs`](examples/memtable.rs): Create and query data in memory using SQL and `RecordBatch`es
+- [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to rewrite predicates with a special operator
 - [`parquet_index.rs`](examples/parquet_index.rs): Create a secondary index over several parquet files and use it to speed up queries
 - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file
 - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files
diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/analyzer_rule.rs
index e89f68f0302e..8935da965343 100644
--- a/datafusion-examples/examples/analyzer_rule.rs
+++ b/datafusion-examples/examples/analyzer_rule.rs
@@ -31,6 +31,7 @@ use std::sync::Arc;
 /// DataFusion optimization process, and are allowed to change the plan's
 /// semantics (e.g. output types).
 ///
+/// See [optimizer_rule.rs] for an example of an optimizer rule
 #[tokio::main]
 pub async fn main() -> Result<()> {
     // DataFusion includes several built in AnalyzerRules for tasks such as type
@@ -118,7 +119,7 @@ impl MyAnalyzerRule {
     }
 }
 
-/// Return a RecordBatch with made up date
+/// Return a RecordBatch with made up data
 fn person_batch() -> RecordBatch {
     let name: ArrayRef =
         Arc::new(StringArray::from_iter_values(["Andy", "Andrew", "Oleks"]));
diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs
index d8965888eab6..7b349bb5a005 100644
--- a/datafusion-examples/examples/optimizer_rule.rs
+++ b/datafusion-examples/examples/optimizer_rule.rs
@@ -15,241 +15,207 @@
 // specific language governing permissions and limitations
 // under the License.
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{plan_err, Result, ScalarValue}; +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use arrow_schema::DataType; +use datafusion::prelude::SessionContext; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::{ - AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, + BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, }; -use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; -use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule}; -use datafusion_sql::planner::{ContextProvider, SqlToRel}; -use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; -use datafusion_sql::sqlparser::parser::Parser; -use datafusion_sql::TableReference; +use datafusion_optimizer::optimizer::ApplyOrder; +use datafusion_optimizer::{OptimizerConfig, OptimizerRule}; use std::any::Any; use std::sync::Arc; -pub fn main() -> Result<()> { - // produce a logical plan using the datafusion-sql crate - let dialect = PostgreSqlDialect {}; - let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32"; - let statements = Parser::parse_sql(&dialect, sql)?; - - // produce a logical plan using the datafusion-sql crate - let context_provider = MyContextProvider::default(); - let sql_to_rel = SqlToRel::new(&context_provider); - let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?; - println!( - "Unoptimized Logical Plan:\n\n{}\n", - logical_plan.display_indent() - ); - - // run the analyzer with our custom rule - let config = OptimizerContext::default().with_skip_failing_rules(false); - let analyzer = Analyzer::with_rules(vec![Arc::new(MyAnalyzerRule {})]); - let analyzed_plan = - analyzer.execute_and_check(logical_plan, config.options(), |_, _| {})?; - println!( - "Analyzed Logical Plan:\n\n{}\n", - analyzed_plan.display_indent() - ); - - // then run the optimizer with our custom rule - let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); - let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?; - println!( - "Optimized Logical Plan:\n\n{}\n", - optimized_plan.display_indent() - ); +/// This example demonstrates how to add your own [`OptimizerRule`] +/// to DataFusion. +/// +/// [`OptimizerRule`]s transform [`LogicalPlan`]s into an equivalent (but +/// hopefully faster) form. +/// +/// See [analyzer_rule.rs] for an example of AnalyzerRules, which are for +/// changing plan semantics. +#[tokio::main] +pub async fn main() -> Result<()> { + // DataFusion includes many built in OptimizerRules for tasks such as outer + // to inner join conversion and constant folding. 
+
+    // To plan and run queries with the new rule, create a SessionContext with
+    // the modified SessionState
+    let ctx = SessionContext::from(state);
+    ctx.register_batch("person", person_batch())?;
+
+    // Plan a SQL statement as normal
+    let sql = "SELECT * FROM person WHERE age = 22";
+    let plan = ctx.sql(sql).await?.into_optimized_plan()?;
+
+    println!("Logical Plan:\n\n{}\n", plan.display_indent());
+
+    // The optimized plan shows the effect of the rewrite: the filter
+    // predicate has been rewritten to a call to my_eq
+
+    // Filter: my_eq(person.age, Int32(22))
+    //   TableScan: person projection=[name, age]
+
+    ctx.sql(sql).await?.show().await?;
+
+    // The query output confirms the predicate has been changed, as the my_eq
+    // function always returns true
+
+    // +--------+-----+
+    // | name   | age |
+    // +--------+-----+
+    // | Andy   | 11  |
+    // | Andrew | 22  |
+    // | Oleks  | 33  |
+    // +--------+-----+
+
+    // However, the rule doesn't trigger for queries with non-equality
+    // predicates
+    ctx.sql("SELECT * FROM person WHERE age <> 22")
+        .await?
+        .show()
+        .await?;
+
+    // +-------+-----+
+    // | name  | age |
+    // +-------+-----+
+    // | Andy  | 11  |
+    // | Oleks | 33  |
+    // +-------+-----+
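+
+    // The rule likewise leaves predicates alone when either side is not a
+    // plain column or literal, such as `age + 1 = 23`, which is evaluated
+    // normally (only Andrew, whose age is 22, should match)
+    ctx.sql("SELECT * FROM person WHERE age + 1 = 23")
+        .await?
+        .show()
+        .await?;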
 
     Ok(())
 }
 
-fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) {
-    println!(
-        "After applying rule '{}':\n{}\n",
-        rule.name(),
-        plan.display_indent()
-    )
-}
-
-/// An example analyzer rule that changes Int64 literals to UInt64
-struct MyAnalyzerRule {}
-
-impl AnalyzerRule for MyAnalyzerRule {
-    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
-        Self::analyze_plan(plan)
-    }
+/// An example optimizer rule that looks for `col = <value>` predicates and
+/// replaces them with a call to a user defined function
+struct MyOptimizerRule {}
 
+impl OptimizerRule for MyOptimizerRule {
     fn name(&self) -> &str {
-        "my_analyzer_rule"
+        "my_optimizer_rule"
     }
-}
 
-impl MyAnalyzerRule {
-    fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
-        plan.transform(|plan| {
-            Ok(match plan {
-                LogicalPlan::Filter(filter) => {
-                    let predicate = Self::analyze_expr(filter.predicate.clone())?;
-                    Transformed::yes(LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        filter.input,
-                    )?))
-                }
-                _ => Transformed::no(plan),
-            })
-        })
-        .data()
+    // New OptimizerRules should use the "rewrite" API as it is more efficient
+    fn supports_rewrite(&self) -> bool {
+        true
     }
 
-    fn analyze_expr(expr: Expr) -> Result<Expr> {
-        expr.transform(|expr| {
-            // closure is invoked for all sub expressions
-            Ok(match expr {
-                Expr::Literal(ScalarValue::Int64(i)) => {
-                    // transform to UInt64
-                    Transformed::yes(Expr::Literal(ScalarValue::UInt64(
-                        i.map(|i| i as u64),
-                    )))
-                }
-                _ => Transformed::no(expr),
-            })
-        })
-        .data()
+    /// Ask the optimizer to handle the plan recursion. `rewrite` will be called
+    /// on each plan node.
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
     }
-}
 
-/// An example optimizer rule that rewrite BETWEEN expression to binary compare expressions
-struct MyOptimizerRule {}
-
-impl OptimizerRule for MyOptimizerRule {
-    fn name(&self) -> &str {
-        "my_optimizer_rule"
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        plan.map_expressions(|expr| {
+            // This closure is called for all expressions in the current plan
+            //
+            // For example, given a plan like `SELECT a + b, 5 + 10`
+            //
+            // The closure would be called twice, once for `a + b` and once
+            // for `5 + 10`
+            self.rewrite_expr(expr)
+        })
     }
 
     fn try_optimize(
         &self,
-        plan: &LogicalPlan,
-        config: &dyn OptimizerConfig,
+        _plan: &LogicalPlan,
+        _config: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
-        // recurse down and optimize children first
-        let optimized_plan = utils::optimize_children(self, plan, config)?;
-        match optimized_plan {
-            Some(LogicalPlan::Filter(filter)) => {
-                let predicate = my_rewrite(filter.predicate.clone())?;
-                Ok(Some(LogicalPlan::Filter(Filter::try_new(
-                    predicate,
-                    filter.input,
-                )?)))
-            }
-            Some(optimized_plan) => Ok(Some(optimized_plan)),
-            None => match plan {
-                LogicalPlan::Filter(filter) => {
-                    let predicate = my_rewrite(filter.predicate.clone())?;
-                    Ok(Some(LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        filter.input.clone(),
-                    )?)))
-                }
-                _ => Ok(None),
-            },
-        }
+        // since this rule uses the rewrite API, return an error if the old API is called
+        internal_err!("Should have called rewrite")
     }
 }
 
-/// use rewrite_expr to modify the expression tree.
-fn my_rewrite(expr: Expr) -> Result<Expr> {
-    expr.transform(|expr| {
-        // closure is invoked for all sub expressions
-        Ok(match expr {
-            Expr::Between(Between {
-                expr,
-                negated,
-                low,
-                high,
-            }) => {
-                // unbox
-                let expr: Expr = *expr;
-                let low: Expr = *low;
-                let high: Expr = *high;
-                if negated {
-                    Transformed::yes(expr.clone().lt(low).or(expr.gt(high)))
-                } else {
-                    Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high)))
+impl MyOptimizerRule {
+    /// Rewrites an Expr, replacing all `col = <value>` expressions with
+    /// a call to the my_eq udf
+    fn rewrite_expr(&self, expr: Expr) -> Result<Transformed<Expr>> {
+        // do a bottom up rewrite of the expression tree
+        expr.transform_up(|expr| {
+            // Closure called for each sub tree
+            match expr {
+                Expr::BinaryExpr(binary_expr) if is_binary_eq(&binary_expr) => {
+                    // destructure the expression
+                    let BinaryExpr { left, op: _, right } = binary_expr;
+                    // rewrite to `my_eq(left, right)`
+                    let udf = ScalarUDF::new_from_impl(MyEq::new());
+                    let call = udf.call(vec![*left, *right]);
+                    Ok(Transformed::yes(call))
                 }
+                _ => Ok(Transformed::no(expr)),
             }
-            _ => Transformed::no(expr),
         })
-    })
-    .data()
+        // Note that the TreeNode API handles propagating the transformed flag
+        // and errors up the call chain
+    }
 }
 
-#[derive(Default)]
-struct MyContextProvider {
-    options: ConfigOptions,
+/// Return true if the expression is an equality predicate with a literal or
+/// column reference on each side
+fn is_binary_eq(binary_expr: &BinaryExpr) -> bool {
+    binary_expr.op == Operator::Eq
+        && is_lit_or_col(binary_expr.left.as_ref())
+        && is_lit_or_col(binary_expr.right.as_ref())
 }
 
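+// For example, `is_binary_eq` is true for `age = 22` and `22 = age`, but
+// false for `age > 22` (not an equality) and for `age + 1 = 23` (the left
+// side is neither a column nor a literal)
+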
-impl ContextProvider for MyContextProvider {
-    fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
-        if name.table() == "person" {
-            Ok(Arc::new(MyTableSource {
-                schema: Arc::new(Schema::new(vec![
-                    Field::new("name", DataType::Utf8, false),
-                    Field::new("age", DataType::UInt8, false),
-                ])),
-            }))
-        } else {
-            plan_err!("table not found")
-        }
-    }
 
-    fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
-        None
-    }
+/// Return true if the expression is a literal or column reference
+fn is_lit_or_col(expr: &Expr) -> bool {
+    matches!(expr, Expr::Column(_) | Expr::Literal(_))
+}
 
-    fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
-        None
-    }
+/// A simple user defined scalar function that takes the place of `=` in the
+/// rewritten plan
+#[derive(Debug, Clone)]
+struct MyEq {
+    signature: Signature,
+}
 
-    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
-        None
+impl MyEq {
+    fn new() -> Self {
+        Self {
+            signature: Signature::any(2, Volatility::Stable),
+        }
     }
+}
 
-    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
-        None
+impl ScalarUDFImpl for MyEq {
+    fn as_any(&self) -> &dyn Any {
+        self
     }
 
-    fn options(&self) -> &ConfigOptions {
-        &self.options
+    fn name(&self) -> &str {
+        "my_eq"
     }
 
-    fn udf_names(&self) -> Vec<String> {
-        Vec::new()
+    fn signature(&self) -> &Signature {
+        &self.signature
     }
 
-    fn udaf_names(&self) -> Vec<String> {
-        Vec::new()
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Boolean)
     }
 
-    fn udwf_names(&self) -> Vec<String> {
-        Vec::new()
+    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        // this example simply returns "true", which is not what a real
+        // implementation would do
+        Ok(ColumnarValue::Scalar(ScalarValue::from(true)))
     }
 }
 
-struct MyTableSource {
-    schema: SchemaRef,
-}
-
-impl TableSource for MyTableSource {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
+/// Return a RecordBatch with made up data
+fn person_batch() -> RecordBatch {
+    let name: ArrayRef =
+        Arc::new(StringArray::from_iter_values(["Andy", "Andrew", "Oleks"]));
+    let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22, 33]));
+    RecordBatch::try_from_iter(vec![("name", name), ("age", age)]).unwrap()
 }
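+
+// Run this example with `cargo run --example optimizer_rule`; it should print
+// the plans and the query results shown in the comments above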