diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 425e1fb60..2fc65f989 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -344,12 +344,14 @@ impl fmt::Display for ObjectName { #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub enum ObjectNamePart { Identifier(Ident), + Function(ObjectNamePartFunction), } impl ObjectNamePart { pub fn as_ident(&self) -> Option<&Ident> { match self { ObjectNamePart::Identifier(ident) => Some(ident), + ObjectNamePart::Function(_) => None, } } } @@ -358,10 +360,30 @@ impl fmt::Display for ObjectNamePart { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ObjectNamePart::Identifier(ident) => write!(f, "{ident}"), + ObjectNamePart::Function(func) => write!(f, "{func}"), } } } +/// An object name part that consists of a function that dynamically +/// constructs identifiers. +/// +/// - [Snowflake](https://docs.snowflake.com/en/sql-reference/identifier-literal) +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct ObjectNamePartFunction { + pub name: Ident, + pub args: Vec, +} + +impl fmt::Display for ObjectNamePartFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}(", self.name)?; + write!(f, "{})", display_comma_separated(&self.args)) + } +} + /// Represents an Array Expression, either /// `ARRAY[..]`, or `[..]` #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 144de5923..a1b2e4e0d 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -1671,6 +1671,10 @@ impl Spanned for ObjectNamePart { fn span(&self) -> Span { match self { ObjectNamePart::Identifier(ident) => ident.span, + ObjectNamePart::Function(func) => func + .name + .span + .union(&union_spans(func.args.iter().map(|i| i.span()))), } } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index bc3c55554..88d63f277 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -49,7 +49,7 @@ pub use self::postgresql::PostgreSqlDialect; pub use self::redshift::RedshiftSqlDialect; pub use self::snowflake::SnowflakeDialect; pub use self::sqlite::SQLiteDialect; -use crate::ast::{ColumnOption, Expr, GranteesType, Statement}; +use crate::ast::{ColumnOption, Expr, GranteesType, Ident, Statement}; pub use crate::keywords; use crate::keywords::Keyword; use crate::parser::{Parser, ParserError}; @@ -1076,6 +1076,15 @@ pub trait Dialect: Debug + Any { fn supports_comma_separated_drop_column_list(&self) -> bool { false } + + /// Returns true if the dialect considers the specified ident as a function + /// that returns an identifier. Typically used to generate identifiers + /// programmatically. + /// + /// - [Snowflake](https://docs.snowflake.com/en/sql-reference/identifier-literal) + fn is_identifier_generating_function_name(&self, _ident: &Ident) -> bool { + false + } } /// This represents the operators for which precedence must be defined diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index 212cf217d..12ccd647f 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -367,6 +367,15 @@ impl Dialect for SnowflakeDialect { fn supports_comma_separated_drop_column_list(&self) -> bool { true } + + fn is_identifier_generating_function_name(&self, ident: &Ident) -> bool { + ident.quote_style.is_none() && ident.value.to_lowercase() == "identifier" + } + + // For example: `SELECT IDENTIFIER('alias1').* FROM tbl AS alias1` + fn supports_select_expr_star(&self) -> bool { + true + } } fn parse_file_staging_command(kw: Keyword, parser: &mut Parser) -> Result { diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 839d36459..c4d92fb27 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10353,70 +10353,89 @@ impl<'a> Parser<'a> { } } - /// Parse a possibly qualified, possibly quoted identifier, optionally allowing for wildcards, + /// Parse a possibly qualified, possibly quoted identifier, e.g. + /// `foo` or `myschema."table" + /// + /// The `in_table_clause` parameter indicates whether the object name is a table in a FROM, JOIN, + /// or similar table clause. Currently, this is used only to support unquoted hyphenated identifiers + /// in this context on BigQuery. + pub fn parse_object_name(&mut self, in_table_clause: bool) -> Result { + self.parse_object_name_inner(in_table_clause, false) + } + + /// Parse a possibly qualified, possibly quoted identifier, e.g. + /// `foo` or `myschema."table" + /// + /// The `in_table_clause` parameter indicates whether the object name is a table in a FROM, JOIN, + /// or similar table clause. Currently, this is used only to support unquoted hyphenated identifiers + /// in this context on BigQuery. + /// + /// The `allow_wildcards` parameter indicates whether to allow for wildcards in the object name /// e.g. *, *.*, `foo`.*, or "foo"."bar" - fn parse_object_name_with_wildcards( + fn parse_object_name_inner( &mut self, in_table_clause: bool, allow_wildcards: bool, ) -> Result { - let mut idents = vec![]; - + let mut parts = vec![]; if dialect_of!(self is BigQueryDialect) && in_table_clause { loop { let (ident, end_with_period) = self.parse_unquoted_hyphenated_identifier()?; - idents.push(ident); + parts.push(ObjectNamePart::Identifier(ident)); if !self.consume_token(&Token::Period) && !end_with_period { break; } } } else { loop { - let ident = if allow_wildcards && self.peek_token().token == Token::Mul { + if allow_wildcards && self.peek_token().token == Token::Mul { let span = self.next_token().span; - Ident { + parts.push(ObjectNamePart::Identifier(Ident { value: Token::Mul.to_string(), quote_style: None, span, + })); + } else if dialect_of!(self is BigQueryDialect) && in_table_clause { + let (ident, end_with_period) = self.parse_unquoted_hyphenated_identifier()?; + parts.push(ObjectNamePart::Identifier(ident)); + if !self.consume_token(&Token::Period) && !end_with_period { + break; } + } else if self.dialect.supports_object_name_double_dot_notation() + && parts.len() == 1 + && matches!(self.peek_token().token, Token::Period) + { + // Empty string here means default schema + parts.push(ObjectNamePart::Identifier(Ident::new(""))); } else { - if self.dialect.supports_object_name_double_dot_notation() - && idents.len() == 1 - && self.consume_token(&Token::Period) - { - // Empty string here means default schema - idents.push(Ident::new("")); - } - self.parse_identifier()? - }; - idents.push(ident); + let ident = self.parse_identifier()?; + let part = if self.dialect.is_identifier_generating_function_name(&ident) { + self.expect_token(&Token::LParen)?; + let args: Vec = + self.parse_comma_separated0(Self::parse_function_args, Token::RParen)?; + self.expect_token(&Token::RParen)?; + ObjectNamePart::Function(ObjectNamePartFunction { name: ident, args }) + } else { + ObjectNamePart::Identifier(ident) + }; + parts.push(part); + } + if !self.consume_token(&Token::Period) { break; } } } - Ok(ObjectName::from(idents)) - } - - /// Parse a possibly qualified, possibly quoted identifier, e.g. - /// `foo` or `myschema."table" - /// - /// The `in_table_clause` parameter indicates whether the object name is a table in a FROM, JOIN, - /// or similar table clause. Currently, this is used only to support unquoted hyphenated identifiers - /// in this context on BigQuery. - pub fn parse_object_name(&mut self, in_table_clause: bool) -> Result { - let ObjectName(mut idents) = - self.parse_object_name_with_wildcards(in_table_clause, false)?; // BigQuery accepts any number of quoted identifiers of a table name. // https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#quoted_identifiers if dialect_of!(self is BigQueryDialect) - && idents.iter().any(|part| { + && parts.iter().any(|part| { part.as_ident() .is_some_and(|ident| ident.value.contains('.')) }) { - idents = idents + parts = parts .into_iter() .flat_map(|part| match part.as_ident() { Some(ident) => ident @@ -10435,7 +10454,7 @@ impl<'a> Parser<'a> { .collect() } - Ok(ObjectName(idents)) + Ok(ObjectName(parts)) } /// Parse identifiers @@ -13938,25 +13957,25 @@ impl<'a> Parser<'a> { schemas: self.parse_comma_separated(|p| p.parse_object_name(false))?, }) } else if self.parse_keywords(&[Keyword::RESOURCE, Keyword::MONITOR]) { - Some(GrantObjects::ResourceMonitors(self.parse_comma_separated( - |p| p.parse_object_name_with_wildcards(false, true), - )?)) + Some(GrantObjects::ResourceMonitors( + self.parse_comma_separated(|p| p.parse_object_name(false))?, + )) } else if self.parse_keywords(&[Keyword::COMPUTE, Keyword::POOL]) { - Some(GrantObjects::ComputePools(self.parse_comma_separated( - |p| p.parse_object_name_with_wildcards(false, true), - )?)) + Some(GrantObjects::ComputePools( + self.parse_comma_separated(|p| p.parse_object_name(false))?, + )) } else if self.parse_keywords(&[Keyword::FAILOVER, Keyword::GROUP]) { - Some(GrantObjects::FailoverGroup(self.parse_comma_separated( - |p| p.parse_object_name_with_wildcards(false, true), - )?)) + Some(GrantObjects::FailoverGroup( + self.parse_comma_separated(|p| p.parse_object_name(false))?, + )) } else if self.parse_keywords(&[Keyword::REPLICATION, Keyword::GROUP]) { - Some(GrantObjects::ReplicationGroup(self.parse_comma_separated( - |p| p.parse_object_name_with_wildcards(false, true), - )?)) + Some(GrantObjects::ReplicationGroup( + self.parse_comma_separated(|p| p.parse_object_name(false))?, + )) } else if self.parse_keywords(&[Keyword::EXTERNAL, Keyword::VOLUME]) { - Some(GrantObjects::ExternalVolumes(self.parse_comma_separated( - |p| p.parse_object_name_with_wildcards(false, true), - )?)) + Some(GrantObjects::ExternalVolumes( + self.parse_comma_separated(|p| p.parse_object_name(false))?, + )) } else { let object_type = self.parse_one_of_keywords(&[ Keyword::SEQUENCE, @@ -13973,7 +13992,7 @@ impl<'a> Parser<'a> { Keyword::CONNECTION, ]); let objects = - self.parse_comma_separated(|p| p.parse_object_name_with_wildcards(false, true)); + self.parse_comma_separated(|p| p.parse_object_name_inner(false, true)); match object_type { Some(Keyword::DATABASE) => Some(GrantObjects::Databases(objects?)), Some(Keyword::SCHEMA) => Some(GrantObjects::Schemas(objects?)), diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index ed9bb704d..01c3c1607 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1232,7 +1232,6 @@ fn parse_select_expr_star() { "SELECT 2. * 3 FROM T", ); dialects.verified_only_select("SELECT myfunc().* FROM T"); - dialects.verified_only_select("SELECT myfunc().* EXCEPT (foo) FROM T"); // Invalid let res = dialects.parse_sql_statements("SELECT foo.*.* FROM T"); @@ -1240,6 +1239,11 @@ fn parse_select_expr_star() { ParserError::ParserError("Expected: end of statement, found: .".to_string()), res.unwrap_err() ); + + let dialects = all_dialects_where(|d| { + d.supports_select_expr_star() && d.supports_select_wildcard_except() + }); + dialects.verified_only_select("SELECT myfunc().* EXCEPT (foo) FROM T"); } #[test] diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index e7393d3f8..a164db5bb 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -4232,3 +4232,122 @@ fn test_snowflake_create_view_with_composite_policy_name() { r#"CREATE VIEW X (COL WITH MASKING POLICY foo.bar.baz) AS SELECT * FROM Y"#; snowflake().verified_stmt(create_view_with_tag); } + +#[test] +fn test_snowflake_identifier_function() { + // Using IDENTIFIER to reference a column + match &snowflake() + .verified_only_select("SELECT identifier('email') FROM customers") + .projection[0] + { + SelectItem::UnnamedExpr(Expr::Function(Function { name, args, .. })) => { + assert_eq!(*name, ObjectName::from(vec![Ident::new("identifier")])); + assert_eq!( + *args, + FunctionArguments::List(FunctionArgumentList { + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("email".to_string()).into() + )))], + clauses: vec![], + duplicate_treatment: None + }) + ); + } + _ => unreachable!(), + } + + // Using IDENTIFIER to reference a case-sensitive column + match &snowflake() + .verified_only_select(r#"SELECT identifier('"Email"') FROM customers"#) + .projection[0] + { + SelectItem::UnnamedExpr(Expr::Function(Function { name, args, .. })) => { + assert_eq!(*name, ObjectName::from(vec![Ident::new("identifier")])); + assert_eq!( + *args, + FunctionArguments::List(FunctionArgumentList { + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("\"Email\"".to_string()).into() + )))], + clauses: vec![], + duplicate_treatment: None + }) + ); + } + _ => unreachable!(), + } + + // Using IDENTIFIER to reference an alias of a table + match &snowflake() + .verified_only_select("SELECT identifier('alias1').* FROM tbl AS alias1") + .projection[0] + { + SelectItem::QualifiedWildcard( + SelectItemQualifiedWildcardKind::Expr(Expr::Function(Function { name, args, .. })), + _, + ) => { + assert_eq!(*name, ObjectName::from(vec![Ident::new("identifier")])); + assert_eq!( + *args, + FunctionArguments::List(FunctionArgumentList { + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("alias1".to_string()).into() + )))], + clauses: vec![], + duplicate_treatment: None + }) + ); + } + _ => unreachable!(), + } + + // Using IDENTIFIER to reference a database + match snowflake().verified_stmt("CREATE DATABASE IDENTIFIER('tbl')") { + Statement::CreateDatabase { db_name, .. } => { + assert_eq!( + db_name, + ObjectName(vec![ObjectNamePart::Function(ObjectNamePartFunction { + name: Ident::new("IDENTIFIER"), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("tbl".to_string()).into() + )))] + })]) + ); + } + _ => unreachable!(), + } + + // Using IDENTIFIER to reference a schema + match snowflake().verified_stmt("CREATE SCHEMA IDENTIFIER('db1.sc1')") { + Statement::CreateSchema { schema_name, .. } => { + assert_eq!( + schema_name, + SchemaName::Simple(ObjectName(vec![ObjectNamePart::Function( + ObjectNamePartFunction { + name: Ident::new("IDENTIFIER"), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("db1.sc1".to_string()).into() + )))] + } + )])) + ); + } + _ => unreachable!(), + } + + // Using IDENTIFIER to reference a table + match snowflake().verified_stmt("CREATE TABLE IDENTIFIER('tbl') (id INT)") { + Statement::CreateTable(CreateTable { name, .. }) => { + assert_eq!( + name, + ObjectName(vec![ObjectNamePart::Function(ObjectNamePartFunction { + name: Ident::new("IDENTIFIER"), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("tbl".to_string()).into() + )))] + })]) + ); + } + _ => unreachable!(), + } +}