Skip to content

Commit 02bffc6

Browse files
committed
Add basic CREATE FUNCTION support for SQL Server
- in this dialect, functions can have statement(s) bodies like stored procedures (including `BEGIN`..`END`) - functions must end with `RETURN`, so a corresponding statement type is also introduced
1 parent 4a48729 commit 02bffc6

File tree

6 files changed

+284
-8
lines changed

6 files changed

+284
-8
lines changed

src/ast/ddl.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2272,6 +2272,12 @@ impl fmt::Display for CreateFunction {
22722272
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
22732273
write!(f, " AS {function_body}")?;
22742274
}
2275+
if let Some(CreateFunctionBody::MultiStatement(statements)) = &self.function_body {
2276+
write!(f, " AS")?;
2277+
write!(f, " BEGIN")?;
2278+
write!(f, " {}", display_separated(statements, "; "))?;
2279+
write!(f, " END")?;
2280+
}
22752281
Ok(())
22762282
}
22772283
}

src/ast/mod.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3614,6 +3614,7 @@ pub enum Statement {
36143614
/// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction)
36153615
/// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html)
36163616
/// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement)
3617+
/// 4. [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
36173618
CreateFunction(CreateFunction),
36183619
/// CREATE TRIGGER
36193620
///
@@ -4060,6 +4061,12 @@ pub enum Statement {
40604061
///
40614062
/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/print-transact-sql>
40624063
Print(PrintStatement),
4064+
/// ```sql
4065+
/// RETURN [ expression ]
4066+
/// ```
4067+
///
4068+
/// See [ReturnStatement]
4069+
Return(ReturnStatement),
40634070
}
40644071

40654072
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
@@ -5752,6 +5759,7 @@ impl fmt::Display for Statement {
57525759
Ok(())
57535760
}
57545761
Statement::Print(s) => write!(f, "{s}"),
5762+
Statement::Return(r) => write!(f, "{r}"),
57555763
Statement::List(command) => write!(f, "LIST {command}"),
57565764
Statement::Remove(command) => write!(f, "REMOVE {command}"),
57575765
}
@@ -8354,6 +8362,7 @@ impl fmt::Display for FunctionDeterminismSpecifier {
83548362
///
83558363
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
83568364
/// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html
8365+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
83578366
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
83588367
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
83598368
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@@ -8382,6 +8391,22 @@ pub enum CreateFunctionBody {
83828391
///
83838392
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
83848393
AsAfterOptions(Expr),
8394+
/// Function body with statements before the `RETURN` keyword.
8395+
///
8396+
/// Example:
8397+
/// ```sql
8398+
/// CREATE FUNCTION my_scalar_udf(a INT, b INT)
8399+
/// RETURNS INT
8400+
/// AS
8401+
/// BEGIN
8402+
/// DECLARE c INT;
8403+
/// SET c = a + b;
8404+
/// RETURN c;
8405+
/// END
8406+
/// ```
8407+
///
8408+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
8409+
MultiStatement(Vec<Statement>),
83858410
/// Function body expression using the 'RETURN' keyword.
83868411
///
83878412
/// Example:
@@ -9230,6 +9255,41 @@ impl fmt::Display for PrintStatement {
92309255
}
92319256
}
92329257

9258+
/// Return (MsSql)
9259+
///
9260+
/// for Functions:
9261+
/// RETURN scalar_expression
9262+
///
9263+
/// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql>
9264+
///
9265+
/// for Triggers:
9266+
/// RETURN
9267+
///
9268+
/// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql>
9269+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9270+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9271+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9272+
pub struct ReturnStatement {
9273+
pub value: Option<ReturnStatementValue>,
9274+
}
9275+
9276+
impl fmt::Display for ReturnStatement {
9277+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
9278+
match &self.value {
9279+
Some(ReturnStatementValue::Expr(expr)) => write!(f, "RETURN {}", expr),
9280+
None => write!(f, "RETURN"),
9281+
}
9282+
}
9283+
}
9284+
9285+
/// Variants of a `RETURN` statement
9286+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9287+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9288+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9289+
pub enum ReturnStatementValue {
9290+
Expr(Expr),
9291+
}
9292+
92339293
#[cfg(test)]
92349294
mod tests {
92359295
use super::*;

src/ast/spans.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,7 @@ impl Spanned for Statement {
520520
Statement::RenameTable { .. } => Span::empty(),
521521
Statement::RaisError { .. } => Span::empty(),
522522
Statement::Print { .. } => Span::empty(),
523+
Statement::Return { .. } => Span::empty(),
523524
Statement::List(..) | Statement::Remove(..) => Span::empty(),
524525
}
525526
}

src/parser/mod.rs

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -577,13 +577,7 @@ impl<'a> Parser<'a> {
577577
Keyword::GRANT => self.parse_grant(),
578578
Keyword::REVOKE => self.parse_revoke(),
579579
Keyword::START => self.parse_start_transaction(),
580-
// `BEGIN` is a nonstandard but common alias for the
581-
// standard `START TRANSACTION` statement. It is supported
582-
// by at least PostgreSQL and MySQL.
583580
Keyword::BEGIN => self.parse_begin(),
584-
// `END` is a nonstandard but common alias for the
585-
// standard `COMMIT TRANSACTION` statement. It is supported
586-
// by PostgreSQL.
587581
Keyword::END => self.parse_end(),
588582
Keyword::SAVEPOINT => self.parse_savepoint(),
589583
Keyword::RELEASE => self.parse_release(),
@@ -618,6 +612,7 @@ impl<'a> Parser<'a> {
618612
// `COMMENT` is snowflake specific https://docs.snowflake.com/en/sql-reference/sql/comment
619613
Keyword::COMMENT if self.dialect.supports_comment_on() => self.parse_comment(),
620614
Keyword::PRINT => self.parse_print(),
615+
Keyword::RETURN => self.parse_return(),
621616
_ => self.expected("an SQL statement", next_token),
622617
},
623618
Token::LParen => {
@@ -4880,6 +4875,8 @@ impl<'a> Parser<'a> {
48804875
self.parse_create_macro(or_replace, temporary)
48814876
} else if dialect_of!(self is BigQueryDialect) {
48824877
self.parse_bigquery_create_function(or_replace, temporary)
4878+
} else if dialect_of!(self is MsSqlDialect) {
4879+
self.parse_mssql_create_function(or_replace, temporary)
48834880
} else {
48844881
self.prev_token();
48854882
self.expected("an object type after CREATE", self.peek_token())
@@ -5134,6 +5131,72 @@ impl<'a> Parser<'a> {
51345131
}))
51355132
}
51365133

5134+
/// Parse `CREATE FUNCTION` for [MsSql]
5135+
///
5136+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
5137+
fn parse_mssql_create_function(
5138+
&mut self,
5139+
or_replace: bool,
5140+
temporary: bool,
5141+
) -> Result<Statement, ParserError> {
5142+
let name = self.parse_object_name(false)?;
5143+
5144+
let parse_function_param =
5145+
|parser: &mut Parser| -> Result<OperateFunctionArg, ParserError> {
5146+
let name = parser.parse_identifier()?;
5147+
let data_type = parser.parse_data_type()?;
5148+
Ok(OperateFunctionArg {
5149+
mode: None,
5150+
name: Some(name),
5151+
data_type,
5152+
default_expr: None,
5153+
})
5154+
};
5155+
self.expect_token(&Token::LParen)?;
5156+
let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?;
5157+
self.expect_token(&Token::RParen)?;
5158+
5159+
let return_type = if self.parse_keyword(Keyword::RETURNS) {
5160+
Some(self.parse_data_type()?)
5161+
} else {
5162+
return parser_err!("Expected RETURNS keyword", self.peek_token().span.start);
5163+
};
5164+
5165+
self.expect_keyword_is(Keyword::AS)?;
5166+
self.expect_keyword_is(Keyword::BEGIN)?;
5167+
let mut result = self.parse_statements()?;
5168+
// note: `parse_statements` will consume the `END` token & produce a Commit statement...
5169+
if let Some(Statement::Commit {
5170+
chain,
5171+
end,
5172+
modifier,
5173+
}) = result.last()
5174+
{
5175+
if *chain == false && *end == true && *modifier == None {
5176+
result = result[..result.len() - 1].to_vec();
5177+
}
5178+
}
5179+
let function_body = Some(CreateFunctionBody::MultiStatement(result));
5180+
5181+
Ok(Statement::CreateFunction(CreateFunction {
5182+
or_replace,
5183+
temporary,
5184+
if_not_exists: false,
5185+
name,
5186+
args: Some(args),
5187+
return_type,
5188+
function_body,
5189+
language: None,
5190+
determinism_specifier: None,
5191+
options: None,
5192+
remote_connection: None,
5193+
using: None,
5194+
behavior: None,
5195+
called_on_null: None,
5196+
parallel: None,
5197+
}))
5198+
}
5199+
51375200
fn parse_function_arg(&mut self) -> Result<OperateFunctionArg, ParserError> {
51385201
let mode = if self.parse_keyword(Keyword::IN) {
51395202
Some(ArgMode::In)
@@ -15063,6 +15126,13 @@ impl<'a> Parser<'a> {
1506315126
message: Box::new(self.parse_expr()?),
1506415127
}))
1506515128
}
15129+
/// Parse [Statement::Return]
15130+
fn parse_return(&mut self) -> Result<Statement, ParserError> {
15131+
let expr = self.parse_expr()?;
15132+
Ok(Statement::Return(ReturnStatement {
15133+
value: Some(ReturnStatementValue::Expr(expr)),
15134+
}))
15135+
}
1506615136

1506715137
/// Consume the parser and return its underlying token buffer
1506815138
pub fn into_tokens(self) -> Vec<TokenWithSpan> {

tests/sqlparser_hive.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use sqlparser::ast::{
2525
Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr,
2626
OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value,
2727
};
28-
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect};
28+
use sqlparser::dialect::{AnsiDialect, GenericDialect, HiveDialect};
2929
use sqlparser::parser::ParserError;
3030
use sqlparser::test_utils::*;
3131

@@ -423,7 +423,7 @@ fn parse_create_function() {
423423
}
424424

425425
// Test error in dialect that doesn't support parsing CREATE FUNCTION
426-
let unsupported_dialects = TestedDialects::new(vec![Box::new(MsSqlDialect {})]);
426+
let unsupported_dialects = TestedDialects::new(vec![Box::new(AnsiDialect {})]);
427427

428428
assert_eq!(
429429
unsupported_dialects.parse_sql_statements(sql).unwrap_err(),

0 commit comments

Comments
 (0)