From 0ed924b6f1d53046f25af87111b4e8e9138d8a97 Mon Sep 17 00:00:00 2001 From: Liam Bao Date: Sat, 14 Jun 2025 22:45:38 -0400 Subject: [PATCH] datafusion-cli: Use correct S3 region if it is not specified --- datafusion-cli/src/catalog.rs | 2 +- datafusion-cli/src/exec.rs | 56 ++++++++---- datafusion-cli/src/object_storage.rs | 89 +++++++++++++++++-- datafusion-cli/tests/cli_integration.rs | 29 ++++++ .../snapshots/aws_region_auto_resolve.snap | 28 ++++++ 5 files changed, 181 insertions(+), 23 deletions(-) create mode 100644 datafusion-cli/tests/snapshots/aws_region_auto_resolve.snap diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index 3298b7deaeba..fd83b52de299 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -200,6 +200,7 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider { table_url.scheme(), url, &state.default_table_options(), + false, ) .await?; state.runtime_env().register_object_store(url, store); @@ -229,7 +230,6 @@ pub fn substitute_tilde(cur: String) -> String { } #[cfg(test)] mod tests { - use super::*; use datafusion::catalog::SchemaProvider; diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 3c2a6e68bbe1..4336382b5e6c 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -26,28 +26,28 @@ use crate::{ object_storage::get_object_store, print_options::{MaxRows, PrintOptions}, }; -use futures::StreamExt; -use std::collections::HashMap; -use std::fs::File; -use std::io::prelude::*; -use std::io::BufReader; - use datafusion::common::instant::Instant; use datafusion::common::{plan_datafusion_err, plan_err}; use datafusion::config::ConfigFileType; use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; use datafusion::physical_plan::execution_plan::EmissionType; +use 
datafusion::physical_plan::spill::get_record_batch_memory_size; use datafusion::physical_plan::{execute_stream, ExecutionPlanProperties}; use datafusion::sql::parser::{DFParser, Statement}; -use datafusion::sql::sqlparser::dialect::dialect_from_str; - -use datafusion::execution::memory_pool::MemoryConsumer; -use datafusion::physical_plan::spill::get_record_batch_memory_size; use datafusion::sql::sqlparser; +use datafusion::sql::sqlparser::dialect::dialect_from_str; +use futures::StreamExt; +use log::warn; +use object_store::Error::Generic; use rustyline::error::ReadlineError; use rustyline::Editor; +use std::collections::HashMap; +use std::fs::File; +use std::io::prelude::*; +use std::io::BufReader; use tokio::signal; /// run and execute SQL statements and commands, against a context with the given print options @@ -231,10 +231,24 @@ pub(super) async fn exec_and_print( let adjusted = AdjustedPrintOptions::new(print_options.clone()).with_statement(&statement); - let plan = create_plan(ctx, statement).await?; + // Only clone the statement if it's a CreateExternalTable + let statement_for_retry = matches!(&statement, Statement::CreateExternalTable(_)) + .then(|| statement.clone()); + + let plan = create_plan(ctx, statement, false).await?; let adjusted = adjusted.with_plan(&plan); - let df = ctx.execute_logical_plan(plan).await?; + let df = match ctx.execute_logical_plan(plan).await { + Ok(df) => df, + Err(DataFusionError::ObjectStore(Generic { store, source: _ })) + if "S3".eq_ignore_ascii_case(store) && statement_for_retry.is_some() => + { + warn!("S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration."); + let plan = create_plan(ctx, statement_for_retry.unwrap(), true).await?; + ctx.execute_logical_plan(plan).await? 
+ } + Err(e) => return Err(e), + }; let physical_plan = df.create_physical_plan().await?; // Track memory usage for the query result if it's bounded @@ -348,6 +362,7 @@ fn config_file_type_from_str(ext: &str) -> Option { async fn create_plan( ctx: &dyn CliSessionContext, statement: Statement, + resolve_region: bool, ) -> Result { let mut plan = ctx.session_state().statement_to_plan(statement).await?; @@ -362,6 +377,7 @@ async fn create_plan( &cmd.location, &cmd.options, format, + resolve_region, ) .await?; } @@ -374,6 +390,7 @@ async fn create_plan( &copy_to.output_url, &copy_to.options, format, + false, ) .await?; } @@ -412,6 +429,7 @@ pub(crate) async fn register_object_store_and_config_extensions( location: &String, options: &HashMap, format: Option, + resolve_region: bool, ) -> Result<()> { // Parse the location URL to extract the scheme and other components let table_path = ListingTableUrl::parse(location)?; @@ -433,8 +451,14 @@ pub(crate) async fn register_object_store_and_config_extensions( table_options.alter_with_string_hash_map(options)?; // Retrieve the appropriate object store based on the scheme, URL, and modified table options - let store = - get_object_store(&ctx.session_state(), scheme, url, &table_options).await?; + let store = get_object_store( + &ctx.session_state(), + scheme, + url, + &table_options, + resolve_region, + ) + .await?; // Register the retrieved object store in the session context's runtime environment ctx.register_object_store(url, store); @@ -462,6 +486,7 @@ mod tests { &cmd.location, &cmd.options, format, + false, ) .await?; } else { @@ -488,6 +513,7 @@ mod tests { &cmd.output_url, &cmd.options, format, + false, ) .await?; } else { @@ -534,7 +560,7 @@ mod tests { let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { //Should not fail - let mut plan = create_plan(&ctx, statement).await?; + let mut plan = create_plan(&ctx, statement, false).await?; if let LogicalPlan::Copy(copy_to) = &mut
plan { assert_eq!(copy_to.output_url, location); assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index 01ba28609642..176dfdd4ceed 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -32,15 +32,28 @@ use aws_config::BehaviorVersion; use aws_credential_types::provider::error::CredentialsError; use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider}; use log::debug; -use object_store::aws::{AmazonS3Builder, AwsCredential}; +use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey, AwsCredential}; use object_store::gcp::GoogleCloudStorageBuilder; use object_store::http::HttpBuilder; use object_store::{ClientOptions, CredentialProvider, ObjectStore}; use url::Url; +#[cfg(not(test))] +use object_store::aws::resolve_bucket_region; + +// Provide a local mock when running tests so we don't make network calls +#[cfg(test)] +async fn resolve_bucket_region( + _bucket: &str, + _client_options: &ClientOptions, +) -> object_store::Result<String> { + Ok("eu-central-1".to_string()) +} + pub async fn get_s3_object_store_builder( url: &Url, aws_options: &AwsOptions, + resolve_region: bool, ) -> Result { let AwsOptions { access_key_id, @@ -88,6 +101,16 @@ pub async fn get_s3_object_store_builder( builder = builder.with_region(region); } + // If the region is not set or resolve_region is true, resolve the region.
+ if builder + .get_config_value(&AmazonS3ConfigKey::Region) + .is_none() + || resolve_region + { + let region = resolve_bucket_region(bucket_name, &ClientOptions::new()).await?; + builder = builder.with_region(region); + } + if let Some(endpoint) = endpoint { // Make a nicer error if the user hasn't allowed http and the endpoint // is http as the default message is "URL scheme is not allowed" @@ -470,6 +493,7 @@ pub(crate) async fn get_object_store( scheme: &str, url: &Url, table_options: &TableOptions, + resolve_region: bool, ) -> Result<Arc<dyn ObjectStore>, DataFusionError> { let store: Arc<dyn ObjectStore> = match scheme { "s3" => { @@ -478,7 +502,8 @@ pub(crate) async fn get_object_store( "Given table options incompatible with the 's3' scheme" ); }; - let builder = get_s3_object_store_builder(url, options).await?; + let builder = + get_s3_object_store_builder(url, options, resolve_region).await?; Arc::new(builder.build()?) } "oss" => { @@ -557,12 +582,14 @@ mod tests { let table_options = get_table_options(&ctx, &sql).await; let aws_options = table_options.extensions.get::<AwsOptions>().unwrap(); let builder = - get_s3_object_store_builder(table_url.as_ref(), aws_options).await?; + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; // If the environment variables are set (as they are in CI) use them let expected_access_key_id = std::env::var("AWS_ACCESS_KEY_ID").ok(); let expected_secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY").ok(); - let expected_region = std::env::var("AWS_REGION").ok(); + let expected_region = Some( + std::env::var("AWS_REGION").unwrap_or_else(|_| "eu-central-1".to_string()), + ); let expected_endpoint = std::env::var("AWS_ENDPOINT").ok(); // get the actual configuration information, then assert_eq!
@@ -624,7 +651,7 @@ mod tests { let table_options = get_table_options(&ctx, &sql).await; let aws_options = table_options.extensions.get::<AwsOptions>().unwrap(); let builder = - get_s3_object_store_builder(table_url.as_ref(), aws_options).await?; + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; // get the actual configuration information, then assert_eq! let config = [ (AmazonS3ConfigKey::AccessKeyId, access_key_id), @@ -667,7 +694,7 @@ mod tests { let table_options = get_table_options(&ctx, &sql).await; let aws_options = table_options.extensions.get::<AwsOptions>().unwrap(); - let err = get_s3_object_store_builder(table_url.as_ref(), aws_options) + let err = get_s3_object_store_builder(table_url.as_ref(), aws_options, false) .await .unwrap_err(); @@ -686,7 +713,55 @@ mod tests { let aws_options = table_options.extensions.get::<AwsOptions>().unwrap(); // ensure this isn't an error - get_s3_object_store_builder(table_url.as_ref(), aws_options).await?; + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; + + Ok(()) + } + + #[tokio::test] + async fn s3_object_store_builder_resolves_region_when_none_provided() -> Result<()> { + let expected_region = "eu-central-1"; + let location = "s3://test-bucket/path/file.parquet"; + + let table_url = ListingTableUrl::parse(location)?; + let aws_options = AwsOptions { + region: None, // No region specified - should auto-detect + ..Default::default() + }; + + let builder = + get_s3_object_store_builder(table_url.as_ref(), &aws_options, false).await?; + + // Verify that the region was auto-detected in test environment + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Region), + Some(expected_region.to_string()) + ); + + Ok(()) + } + + #[tokio::test] + async fn s3_object_store_builder_overrides_region_when_resolve_region_enabled( + ) -> Result<()> { + let original_region = "us-east-1"; + let expected_region = "eu-central-1"; // This should be the auto-detected region + let location =
"s3://test-bucket/path/file.parquet"; + + let table_url = ListingTableUrl::parse(location)?; + let aws_options = AwsOptions { + region: Some(original_region.to_string()), // Explicit region provided + ..Default::default() + }; + + let builder = + get_s3_object_store_builder(table_url.as_ref(), &aws_options, true).await?; + + // Verify that the region was overridden by auto-detection + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Region), + Some(expected_region.to_string()) + ); Ok(()) } diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index fb2f08157f67..9de5010f1298 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -209,3 +209,32 @@ SELECT * FROM CARS limit 1; assert_cmd_snapshot!(cli().env_clear().pass_stdin(input)); } + +#[tokio::test] +async fn test_aws_region_auto_resolve() { + // Separate test is needed to pass aws as options in sql and not via env + + if env::var("TEST_STORAGE_INTEGRATION").is_err() { + eprintln!("Skipping external storages integration tests"); + return; + } + + let mut settings = make_settings(); + settings.add_filter(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", "[TIME]"); + let _bound = settings.bind_to_scope(); + + let bucket = "s3://clickhouse-public-datasets/hits_compatible/athena_partitioned/hits_1.parquet"; + let region = "us-east-1"; + + let input = format!( + r#"CREATE EXTERNAL TABLE hits +STORED AS PARQUET +LOCATION '{bucket}' +OPTIONS('aws.region' '{region}'); + +SELECT COUNT(*) FROM hits; +"# + ); + + assert_cmd_snapshot!(cli().env_clear().env("RUST_LOG", "warn").pass_stdin(input)); +} diff --git a/datafusion-cli/tests/snapshots/aws_region_auto_resolve.snap b/datafusion-cli/tests/snapshots/aws_region_auto_resolve.snap new file mode 100644 index 000000000000..9a6fb99bc823 --- /dev/null +++ b/datafusion-cli/tests/snapshots/aws_region_auto_resolve.snap @@ -0,0 +1,28 @@ +--- +source: datafusion-cli/tests/cli_integration.rs 
+info: + program: datafusion-cli + args: [] + env: + RUST_LOG: warn + stdin: "CREATE EXTERNAL TABLE hits\nSTORED AS PARQUET\nLOCATION 's3://clickhouse-public-datasets/hits_compatible/athena_partitioned/hits_1.parquet'\nOPTIONS('aws.region' 'us-east-1');\n\nSELECT COUNT(*) FROM hits;\n" +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] +0 row(s) fetched. +[ELAPSED] + ++----------+ +| count(*) | ++----------+ +| 1000000 | ++----------+ +1 row(s) fetched. +[ELAPSED] + +\q + +----- stderr ----- +[[TIME] WARN datafusion_cli::exec] S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration.