Skip to content

Commit 072bbbb

Browse files
committed
Refactor BeginEndStatements into a reusable struct, then use for functions
- this lets us disacard the former (unfortunate, bespoke) multi statement parsing because we can just use `parse_statement_list` - however, `parse_statement_list` also needed a small change to allow subsequent statements to come after the final `END`
1 parent c6bbf6f commit 072bbbb

File tree

6 files changed

+226
-136
lines changed

6 files changed

+226
-136
lines changed

src/ast/ddl.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2277,11 +2277,9 @@ impl fmt::Display for CreateFunction {
22772277
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
22782278
write!(f, " AS {function_body}")?;
22792279
}
2280-
if let Some(CreateFunctionBody::MultiStatement(statements)) = &self.function_body {
2280+
if let Some(CreateFunctionBody::AsBeginEnd(bes)) = &self.function_body {
22812281
write!(f, " AS")?;
2282-
write!(f, " BEGIN")?;
2283-
write!(f, " {}", display_separated(statements, "; "))?;
2284-
write!(f, " END")?;
2282+
write!(f, " {}", bes)?;
22852283
}
22862284
Ok(())
22872285
}

src/ast/mod.rs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,18 +2292,14 @@ pub enum ConditionalStatements {
22922292
/// SELECT 1; SELECT 2; SELECT 3; ...
22932293
Sequence { statements: Vec<Statement> },
22942294
/// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END
2295-
BeginEnd {
2296-
begin_token: AttachedToken,
2297-
statements: Vec<Statement>,
2298-
end_token: AttachedToken,
2299-
},
2295+
BeginEnd(BeginEndStatements),
23002296
}
23012297

23022298
impl ConditionalStatements {
23032299
pub fn statements(&self) -> &Vec<Statement> {
23042300
match self {
23052301
ConditionalStatements::Sequence { statements } => statements,
2306-
ConditionalStatements::BeginEnd { statements, .. } => statements,
2302+
ConditionalStatements::BeginEnd(bes) => &bes.statements,
23072303
}
23082304
}
23092305
}
@@ -2317,12 +2313,34 @@ impl fmt::Display for ConditionalStatements {
23172313
}
23182314
Ok(())
23192315
}
2320-
ConditionalStatements::BeginEnd { statements, .. } => {
2321-
write!(f, "BEGIN ")?;
2322-
format_statement_list(f, statements)?;
2323-
write!(f, " END")
2324-
}
2316+
ConditionalStatements::BeginEnd(bes) => write!(f, "{}", bes),
2317+
}
2318+
}
2319+
}
2320+
2321+
/// A shared representation of `BEGIN`, multiple statements, and `END` tokens.
2322+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2323+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2324+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2325+
pub struct BeginEndStatements {
2326+
pub begin_token: AttachedToken,
2327+
pub statements: Vec<Statement>,
2328+
pub end_token: AttachedToken,
2329+
}
2330+
2331+
impl fmt::Display for BeginEndStatements {
2332+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2333+
let BeginEndStatements {
2334+
begin_token: AttachedToken(begin_token),
2335+
statements,
2336+
end_token: AttachedToken(end_token),
2337+
} = self;
2338+
2339+
write!(f, "{begin_token} ")?;
2340+
if !statements.is_empty() {
2341+
format_statement_list(f, statements)?;
23252342
}
2343+
write!(f, " {end_token}")
23262344
}
23272345
}
23282346

@@ -8406,7 +8424,7 @@ pub enum CreateFunctionBody {
84068424
/// ```
84078425
///
84088426
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
8409-
MultiStatement(Vec<Statement>),
8427+
AsBeginEnd(BeginEndStatements),
84108428
/// Function body expression using the 'RETURN' keyword.
84118429
///
84128430
/// Example:

src/ast/spans.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -779,11 +779,9 @@ impl Spanned for ConditionalStatements {
779779
ConditionalStatements::Sequence { statements } => {
780780
union_spans(statements.iter().map(|s| s.span()))
781781
}
782-
ConditionalStatements::BeginEnd {
783-
begin_token: AttachedToken(start),
784-
statements: _,
785-
end_token: AttachedToken(end),
786-
} => union_spans([start.span, end.span].into_iter()),
782+
ConditionalStatements::BeginEnd(bes) => {
783+
union_spans([bes.begin_token.0.span, bes.end_token.0.span].into_iter())
784+
}
787785
}
788786
}
789787
}

src/dialect/mssql.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
use crate::ast::helpers::attached_token::AttachedToken;
19-
use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement};
19+
use crate::ast::{
20+
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement,
21+
};
2022
use crate::dialect::Dialect;
2123
use crate::keywords::{self, Keyword};
2224
use crate::parser::{Parser, ParserError};
@@ -149,11 +151,11 @@ impl MsSqlDialect {
149151
start_token: AttachedToken(if_token),
150152
condition: Some(condition),
151153
then_token: None,
152-
conditional_statements: ConditionalStatements::BeginEnd {
154+
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
153155
begin_token: AttachedToken(begin_token),
154156
statements,
155157
end_token: AttachedToken(end_token),
156-
},
158+
}),
157159
}
158160
} else {
159161
let stmt = parser.parse_statement()?;
@@ -182,11 +184,11 @@ impl MsSqlDialect {
182184
start_token: AttachedToken(else_token),
183185
condition: None,
184186
then_token: None,
185-
conditional_statements: ConditionalStatements::BeginEnd {
187+
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
186188
begin_token: AttachedToken(begin_token),
187189
statements,
188190
end_token: AttachedToken(end_token),
189-
},
191+
}),
190192
});
191193
} else {
192194
let stmt = parser.parse_statement()?;

src/parser/mod.rs

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4453,9 +4453,17 @@ impl<'a> Parser<'a> {
44534453
break;
44544454
}
44554455
}
4456-
44574456
values.push(self.parse_statement()?);
4458-
self.expect_token(&Token::SemiColon)?;
4457+
4458+
let semi_colon_expected = match values.last() {
4459+
Some(Statement::If(if_statement)) => if_statement.end_token.is_some(),
4460+
Some(_) => true,
4461+
None => false,
4462+
};
4463+
4464+
if semi_colon_expected {
4465+
self.expect_token(&Token::SemiColon)?;
4466+
}
44594467
}
44604468
Ok(values)
44614469
}
@@ -5168,20 +5176,16 @@ impl<'a> Parser<'a> {
51685176
};
51695177

51705178
self.expect_keyword_is(Keyword::AS)?;
5171-
self.expect_keyword_is(Keyword::BEGIN)?;
5172-
let mut result = self.parse_statements()?;
5173-
// note: `parse_statements` will consume the `END` token & produce a Commit statement...
5174-
if let Some(Statement::Commit {
5175-
chain,
5176-
end,
5177-
modifier,
5178-
}) = result.last()
5179-
{
5180-
if *chain == false && *end == true && *modifier == None {
5181-
result = result[..result.len() - 1].to_vec();
5182-
}
5183-
}
5184-
let function_body = Some(CreateFunctionBody::MultiStatement(result));
5179+
5180+
let begin_token = self.expect_keyword(Keyword::BEGIN)?;
5181+
let statements = self.parse_statement_list(&[Keyword::END])?;
5182+
let end_token = self.expect_keyword(Keyword::END)?;
5183+
5184+
let function_body = Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
5185+
begin_token: AttachedToken(begin_token),
5186+
statements,
5187+
end_token: AttachedToken(end_token),
5188+
}));
51855189

51865190
Ok(Statement::CreateFunction(CreateFunction {
51875191
or_alter,

0 commit comments

Comments
 (0)