Skip to content

Commit 6a7aad3

Browse files
committed
Add basic CREATE FUNCTION support for SQL Server
- in this dialect, functions can have statement(s) bodies like stored procedures (including `BEGIN`..`END`) - functions must end with `RETURN`, so a corresponding statement type is also introduced
1 parent 514d2ec commit 6a7aad3

File tree

6 files changed

+275
-9
lines changed

6 files changed

+275
-9
lines changed

src/ast/ddl.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2272,6 +2272,12 @@ impl fmt::Display for CreateFunction {
22722272
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
22732273
write!(f, " AS {function_body}")?;
22742274
}
2275+
if let Some(CreateFunctionBody::MultiStatement(statements)) = &self.function_body {
2276+
write!(f, " AS")?;
2277+
write!(f, " BEGIN")?;
2278+
write!(f, " {}", display_separated(statements, "; "))?;
2279+
write!(f, " END")?;
2280+
}
22752281
Ok(())
22762282
}
22772283
}

src/ast/mod.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3614,6 +3614,7 @@ pub enum Statement {
36143614
/// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction)
36153615
/// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html)
36163616
/// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement)
3617+
/// 4. [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
36173618
CreateFunction(CreateFunction),
36183619
/// CREATE TRIGGER
36193620
///
@@ -4054,6 +4055,13 @@ pub enum Statement {
40544055
arguments: Vec<Expr>,
40554056
options: Vec<RaisErrorOption>,
40564057
},
4058+
/// Return (MsSql)
4059+
///
4060+
/// for Functions:
4061+
/// RETURN scalar_expression
4062+
///
4063+
/// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql>
4064+
Return(ReturnStatement),
40574065
}
40584066

40594067
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
@@ -5745,7 +5753,7 @@ impl fmt::Display for Statement {
57455753
}
57465754
Ok(())
57475755
}
5748-
5756+
Statement::Return(r) => write!(f, "{r}"),
57495757
Statement::List(command) => write!(f, "LIST {command}"),
57505758
Statement::Remove(command) => write!(f, "REMOVE {command}"),
57515759
}
@@ -8348,6 +8356,7 @@ impl fmt::Display for FunctionDeterminismSpecifier {
83488356
///
83498357
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
83508358
/// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html
8359+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
83518360
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
83528361
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
83538362
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@@ -8376,6 +8385,22 @@ pub enum CreateFunctionBody {
83768385
///
83778386
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
83788387
AsAfterOptions(Expr),
8388+
/// Function body with statements before the `RETURN` keyword.
8389+
///
8390+
/// Example:
8391+
/// ```sql
8392+
/// CREATE FUNCTION my_scalar_udf(a INT, b INT)
8393+
/// RETURNS INT
8394+
/// AS
8395+
/// BEGIN
8396+
/// DECLARE c INT;
8397+
/// SET c = a + b;
8398+
/// RETURN c;
8399+
/// END
8400+
/// ```
8401+
///
8402+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
8403+
MultiStatement(Vec<Statement>),
83798404
/// Function body expression using the 'RETURN' keyword.
83808405
///
83818406
/// Example:
@@ -9211,6 +9236,30 @@ pub enum CopyIntoSnowflakeKind {
92119236
Location,
92129237
}
92139238

9239+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9240+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9241+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9242+
pub struct ReturnStatement {
9243+
pub value: Option<ReturnStatementValue>,
9244+
}
9245+
9246+
impl fmt::Display for ReturnStatement {
9247+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
9248+
match &self.value {
9249+
Some(ReturnStatementValue::Expr(expr)) => write!(f, "RETURN {}", expr),
9250+
None => write!(f, "RETURN"),
9251+
}
9252+
}
9253+
}
9254+
9255+
/// Variants of a `RETURN` statement
9256+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9257+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9258+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9259+
pub enum ReturnStatementValue {
9260+
Expr(Expr),
9261+
}
9262+
92149263
#[cfg(test)]
92159264
mod tests {
92169265
use super::*;

src/ast/spans.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ impl Spanned for Statement {
519519
Statement::UNLISTEN { .. } => Span::empty(),
520520
Statement::RenameTable { .. } => Span::empty(),
521521
Statement::RaisError { .. } => Span::empty(),
522+
Statement::Return { .. } => Span::empty(),
522523
Statement::List(..) | Statement::Remove(..) => Span::empty(),
523524
}
524525
}

src/parser/mod.rs

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -577,13 +577,7 @@ impl<'a> Parser<'a> {
577577
Keyword::GRANT => self.parse_grant(),
578578
Keyword::REVOKE => self.parse_revoke(),
579579
Keyword::START => self.parse_start_transaction(),
580-
// `BEGIN` is a nonstandard but common alias for the
581-
// standard `START TRANSACTION` statement. It is supported
582-
// by at least PostgreSQL and MySQL.
583580
Keyword::BEGIN => self.parse_begin(),
584-
// `END` is a nonstandard but common alias for the
585-
// standard `COMMIT TRANSACTION` statement. It is supported
586-
// by PostgreSQL.
587581
Keyword::END => self.parse_end(),
588582
Keyword::SAVEPOINT => self.parse_savepoint(),
589583
Keyword::RELEASE => self.parse_release(),
@@ -617,6 +611,7 @@ impl<'a> Parser<'a> {
617611
}
618612
// `COMMENT` is snowflake specific https://docs.snowflake.com/en/sql-reference/sql/comment
619613
Keyword::COMMENT if self.dialect.supports_comment_on() => self.parse_comment(),
614+
Keyword::RETURN => self.parse_return(),
620615
_ => self.expected("an SQL statement", next_token),
621616
},
622617
Token::LParen => {
@@ -4881,6 +4876,8 @@ impl<'a> Parser<'a> {
48814876
self.parse_create_macro(or_replace, temporary)
48824877
} else if dialect_of!(self is BigQueryDialect) {
48834878
self.parse_bigquery_create_function(or_replace, temporary)
4879+
} else if dialect_of!(self is MsSqlDialect) {
4880+
self.parse_mssql_create_function(or_replace, temporary)
48844881
} else {
48854882
self.prev_token();
48864883
self.expected("an object type after CREATE", self.peek_token())
@@ -5135,6 +5132,72 @@ impl<'a> Parser<'a> {
51355132
}))
51365133
}
51375134

5135+
/// Parse `CREATE FUNCTION` for [MsSql]
5136+
///
5137+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
5138+
fn parse_mssql_create_function(
5139+
&mut self,
5140+
or_replace: bool,
5141+
temporary: bool,
5142+
) -> Result<Statement, ParserError> {
5143+
let name = self.parse_object_name(false)?;
5144+
5145+
let parse_function_param =
5146+
|parser: &mut Parser| -> Result<OperateFunctionArg, ParserError> {
5147+
let name = parser.parse_identifier()?;
5148+
let data_type = parser.parse_data_type()?;
5149+
Ok(OperateFunctionArg {
5150+
mode: None,
5151+
name: Some(name),
5152+
data_type,
5153+
default_expr: None,
5154+
})
5155+
};
5156+
self.expect_token(&Token::LParen)?;
5157+
let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?;
5158+
self.expect_token(&Token::RParen)?;
5159+
5160+
let return_type = if self.parse_keyword(Keyword::RETURNS) {
5161+
Some(self.parse_data_type()?)
5162+
} else {
5163+
return parser_err!("Expected RETURNS keyword", self.peek_token().span.start);
5164+
};
5165+
5166+
self.expect_keyword_is(Keyword::AS)?;
5167+
self.expect_keyword_is(Keyword::BEGIN)?;
5168+
let mut result = self.parse_statements()?;
5169+
// note: `parse_statements` will consume the `END` token & produce a Commit statement...
5170+
if let Some(Statement::Commit {
5171+
chain,
5172+
end,
5173+
modifier,
5174+
}) = result.last()
5175+
{
5176+
if *chain == false && *end == true && *modifier == None {
5177+
result = result[..result.len() - 1].to_vec();
5178+
}
5179+
}
5180+
let function_body = Some(CreateFunctionBody::MultiStatement(result));
5181+
5182+
Ok(Statement::CreateFunction(CreateFunction {
5183+
or_replace,
5184+
temporary,
5185+
if_not_exists: false,
5186+
name,
5187+
args: Some(args),
5188+
return_type,
5189+
function_body,
5190+
language: None,
5191+
determinism_specifier: None,
5192+
options: None,
5193+
remote_connection: None,
5194+
using: None,
5195+
behavior: None,
5196+
called_on_null: None,
5197+
parallel: None,
5198+
}))
5199+
}
5200+
51385201
fn parse_function_arg(&mut self) -> Result<OperateFunctionArg, ParserError> {
51395202
let mode = if self.parse_keyword(Keyword::IN) {
51405203
Some(ArgMode::In)
@@ -15058,6 +15121,14 @@ impl<'a> Parser<'a> {
1505815121
}
1505915122
}
1506015123

15124+
/// Parse [Statement::Return]
15125+
fn parse_return(&mut self) -> Result<Statement, ParserError> {
15126+
let expr = self.parse_expr()?;
15127+
Ok(Statement::Return(ReturnStatement {
15128+
value: Some(ReturnStatementValue::Expr(expr)),
15129+
}))
15130+
}
15131+
1506115132
/// Consume the parser and return its underlying token buffer
1506215133
pub fn into_tokens(self) -> Vec<TokenWithSpan> {
1506315134
self.tokens

tests/sqlparser_hive.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use sqlparser::ast::{
2525
Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr,
2626
OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value,
2727
};
28-
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect};
28+
use sqlparser::dialect::{AnsiDialect, GenericDialect, HiveDialect};
2929
use sqlparser::parser::ParserError;
3030
use sqlparser::test_utils::*;
3131

@@ -423,7 +423,7 @@ fn parse_create_function() {
423423
}
424424

425425
// Test error in dialect that doesn't support parsing CREATE FUNCTION
426-
let unsupported_dialects = TestedDialects::new(vec![Box::new(MsSqlDialect {})]);
426+
let unsupported_dialects = TestedDialects::new(vec![Box::new(AnsiDialect {})]);
427427

428428
assert_eq!(
429429
unsupported_dialects.parse_sql_statements(sql).unwrap_err(),

tests/sqlparser_mssql.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,145 @@ fn parse_mssql_create_procedure() {
187187
let _ = ms().verified_stmt("CREATE PROCEDURE [foo] AS BEGIN UPDATE bar SET col = 'test'; SELECT [foo] FROM BAR WHERE [FOO] > 10 END");
188188
}
189189

190+
#[test]
191+
fn parse_create_function() {
192+
let return_expression_function = "CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) RETURNS INT AS BEGIN RETURN 1 END";
193+
assert_eq!(
194+
ms().verified_stmt(return_expression_function),
195+
sqlparser::ast::Statement::CreateFunction(CreateFunction {
196+
or_replace: false,
197+
temporary: false,
198+
if_not_exists: false,
199+
name: ObjectName::from(vec![Ident {
200+
value: "some_scalar_udf".into(),
201+
quote_style: None,
202+
span: Span::empty(),
203+
}]),
204+
args: Some(vec![
205+
OperateFunctionArg {
206+
mode: None,
207+
name: Some(Ident {
208+
value: "@foo".into(),
209+
quote_style: None,
210+
span: Span::empty(),
211+
}),
212+
data_type: DataType::Int(None),
213+
default_expr: None,
214+
},
215+
OperateFunctionArg {
216+
mode: None,
217+
name: Some(Ident {
218+
value: "@bar".into(),
219+
quote_style: None,
220+
span: Span::empty(),
221+
}),
222+
data_type: DataType::Varchar(Some(CharacterLength::IntegerLength {
223+
length: 256,
224+
unit: None
225+
})),
226+
default_expr: None,
227+
},
228+
]),
229+
return_type: Some(DataType::Int(None)),
230+
function_body: Some(CreateFunctionBody::MultiStatement(vec![
231+
Statement::Return(ReturnStatement {
232+
value: Some(ReturnStatementValue::Expr(Expr::Value((number("1")).with_empty_span()))),
233+
}),
234+
])),
235+
behavior: None,
236+
called_on_null: None,
237+
parallel: None,
238+
using: None,
239+
language: None,
240+
determinism_specifier: None,
241+
options: None,
242+
remote_connection: None,
243+
}),
244+
);
245+
246+
let multi_statement_function = "\
247+
CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) \
248+
RETURNS INT \
249+
AS \
250+
BEGIN \
251+
SET @foo = @foo + 1; \
252+
RETURN @foo \
253+
END\
254+
";
255+
assert_eq!(
256+
ms().verified_stmt(multi_statement_function),
257+
sqlparser::ast::Statement::CreateFunction(CreateFunction {
258+
or_replace: false,
259+
temporary: false,
260+
if_not_exists: false,
261+
name: ObjectName::from(vec![Ident {
262+
value: "some_scalar_udf".into(),
263+
quote_style: None,
264+
span: Span::empty(),
265+
}]),
266+
args: Some(vec![
267+
OperateFunctionArg {
268+
mode: None,
269+
name: Some(Ident {
270+
value: "@foo".into(),
271+
quote_style: None,
272+
span: Span::empty(),
273+
}),
274+
data_type: DataType::Int(None),
275+
default_expr: None,
276+
},
277+
OperateFunctionArg {
278+
mode: None,
279+
name: Some(Ident {
280+
value: "@bar".into(),
281+
quote_style: None,
282+
span: Span::empty(),
283+
}),
284+
data_type: DataType::Varchar(Some(CharacterLength::IntegerLength {
285+
length: 256,
286+
unit: None
287+
})),
288+
default_expr: None,
289+
},
290+
]),
291+
return_type: Some(DataType::Int(None)),
292+
function_body: Some(CreateFunctionBody::MultiStatement(vec![
293+
Statement::Set(Set::SingleAssignment {
294+
scope: None,
295+
hivevar: false,
296+
variable: ObjectName::from(vec!["@foo".into()]),
297+
values: vec![sqlparser::ast::Expr::BinaryOp {
298+
left: Box::new(sqlparser::ast::Expr::Identifier(Ident {
299+
value: "@foo".to_string(),
300+
quote_style: None,
301+
span: Span::empty(),
302+
})),
303+
op: sqlparser::ast::BinaryOperator::Plus,
304+
right: Box::new(Expr::Value(
305+
(Value::Number("1".into(), false)).with_empty_span()
306+
)),
307+
}],
308+
}),
309+
Statement::Return(ReturnStatement{
310+
value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident {
311+
value: "@foo".into(),
312+
quote_style: None,
313+
span: Span::empty(),
314+
}))),
315+
}),
316+
])),
317+
behavior: None,
318+
called_on_null: None,
319+
parallel: None,
320+
using: None,
321+
language: None,
322+
determinism_specifier: None,
323+
options: None,
324+
remote_connection: None,
325+
}),
326+
);
327+
}
328+
190329
#[test]
191330
fn parse_mssql_apply_join() {
192331
let _ = ms_and_generic().verified_only_select(

0 commit comments

Comments
 (0)