Skip to content

Commit 7963416

Browse files
authored
chore(query): add python style list comprehension (#13887)
* chore(query): add python style list comprehension * chore(query): add python style list comprehension
1 parent adac7a0 commit 7963416

File tree

5 files changed

+122
-2
lines changed

5 files changed

+122
-2
lines changed

src/query/ast/src/parser/expr.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,13 @@ pub enum ExprElement {
296296
args: Vec<Expr>,
297297
lambda: Option<Lambda>,
298298
},
299+
/// python/rust list comprehension
300+
ListComprehension {
301+
source: Expr,
302+
param: Identifier,
303+
filter: Option<Expr>,
304+
result: Expr,
305+
},
299306
/// An expression between parentheses
300307
Group(Expr),
301308
/// `[1, 2, 3]`
@@ -518,6 +525,44 @@ impl<'a, I: Iterator<Item = WithSpan<'a, ExprElement>>> PrattParser<I> for ExprP
518525
span: transform_span(elem.span.0),
519526
exprs,
520527
},
528+
ExprElement::ListComprehension {
529+
source,
530+
param,
531+
filter,
532+
result,
533+
} => {
534+
let span = transform_span(elem.span.0);
535+
let mut source = source;
536+
537+
// array_filter(source, filter)
538+
if let Some(filter) = filter {
539+
source = Expr::FunctionCall {
540+
span,
541+
distinct: false,
542+
name: Identifier::from_name("array_filter"),
543+
args: vec![source],
544+
params: vec![],
545+
window: None,
546+
lambda: Some(Lambda {
547+
params: vec![param.clone()],
548+
expr: Box::new(filter),
549+
}),
550+
};
551+
}
552+
// array_map(source, result)
553+
Expr::FunctionCall {
554+
span,
555+
distinct: false,
556+
name: Identifier::from_name("array_map"),
557+
args: vec![source],
558+
params: vec![],
559+
window: None,
560+
lambda: Some(Lambda {
561+
params: vec![param.clone()],
562+
expr: Box::new(result),
563+
}),
564+
}
565+
}
521566
ExprElement::Map { kvs } => Expr::Map {
522567
span: transform_span(elem.span.0),
523568
kvs,
@@ -1022,6 +1067,28 @@ pub fn expr_element(i: Input) -> IResult<WithSpan<ExprElement>> {
10221067
)),
10231068
);
10241069

1070+
// python style list comprehensions
1071+
// python: [i for i in range(10) if i%2==0 ]
1072+
// sql: [i for i in range(10) if i%2 = 0 ]
1073+
let list_comprehensions = check_experimental_chain_function(
1074+
true,
1075+
map(
1076+
rule! {
1077+
"[" ~ #subexpr(0) ~ FOR ~ #ident ~ IN
1078+
~ #subexpr(0) ~ (IF ~ #subexpr(2))? ~ "]"
1079+
},
1080+
|(_, result, _, param, _, source, opt_filter, _)| {
1081+
let filter = opt_filter.map(|(_, filter)| filter);
1082+
ExprElement::ListComprehension {
1083+
source,
1084+
param,
1085+
filter,
1086+
result,
1087+
}
1088+
},
1089+
),
1090+
);
1091+
10251092
// Floating point literal with leading dot will be parsed as a period map access,
10261093
// and then will be converted back to a floating point literal if the map access
10271094
// is not following a primary element nor a postfix element.
@@ -1154,6 +1221,7 @@ pub fn expr_element(i: Input) -> IResult<WithSpan<ExprElement>> {
11541221
| #trim_from : "`TRIM([(BOTH | LEADEING | TRAILING) ... FROM ...)`"
11551222
| #is_distinct_from: "`... IS [NOT] DISTINCT FROM ...`"
11561223
| #chain_function_call : "x.func(...)"
1224+
| #list_comprehensions: "[expr for x in ... [if ...]]"
11571225
| #count_all_with_window : "`COUNT(*) OVER ...`"
11581226
| #function_call : "<function>"
11591227
| #case : "`CASE ... END`"
@@ -1684,6 +1752,11 @@ pub fn parse_float(text: &str) -> Result<Literal, ErrorKind> {
16841752

16851753
pub fn parse_uint(text: &str, radix: u32) -> Result<Literal, ErrorKind> {
16861754
let text = text.trim_start_matches('0');
1755+
let contains_underscore = text.contains('_');
1756+
if contains_underscore {
1757+
let text = text.replace(|p| p == '_', "");
1758+
return parse_uint(&text, radix);
1759+
}
16871760

16881761
if text.is_empty() {
16891762
return Ok(Literal::UInt64(0));

src/query/ast/src/parser/token.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ pub enum TokenKind {
159159
#[regex(r"0[xX][a-fA-F0-9]+")]
160160
MySQLLiteralHex,
161161

162-
#[regex(r"[0-9]+")]
162+
#[regex(r"[0-9]+(_|[0-9])*")]
163163
LiteralInteger,
164164

165165
#[regex(r"[0-9]+[eE][+-]?[0-9]+")]
@@ -1358,6 +1358,7 @@ impl TokenKind {
13581358
| TokenKind::FALSE
13591359
// | TokenKind::FOREIGN
13601360
// | TokenKind::FREEZE
1361+
| TokenKind::FOR
13611362
| TokenKind::FULL
13621363
// | TokenKind::ILIKE
13631364
| TokenKind::IN
@@ -1407,7 +1408,6 @@ impl TokenKind {
14071408
| TokenKind::ATTACH
14081409
| TokenKind::EXCEPT
14091410
// | TokenKind::FETCH
1410-
| TokenKind::FOR
14111411
| TokenKind::FROM
14121412
// | TokenKind::GRANT
14131413
| TokenKind::GROUP

src/query/ast/tests/it/parser.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,8 @@ fn test_expr() {
731731
r#"123456789012345678901234567890"#,
732732
r#"x'123456789012345678901234567890'"#,
733733
r#"1e100000000000000"#,
734+
r#"100_100_000"#,
735+
r#"1_12200_00"#,
734736
r#".1"#,
735737
r#"-1"#,
736738
r#"(1)"#,

src/query/ast/tests/it/testdata/expr.txt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,36 @@ Literal {
257257
}
258258

259259

260+
---------- Input ----------
261+
100_100_000
262+
---------- Output ---------
263+
100100000
264+
---------- AST ------------
265+
Literal {
266+
span: Some(
267+
0..11,
268+
),
269+
lit: UInt64(
270+
100100000,
271+
),
272+
}
273+
274+
275+
---------- Input ----------
276+
1_12200_00
277+
---------- Output ---------
278+
11220000
279+
---------- AST ------------
280+
Literal {
281+
span: Some(
282+
0..10,
283+
),
284+
lit: UInt64(
285+
11220000,
286+
),
287+
}
288+
289+
260290
---------- Input ----------
261291
.1
262292
---------- Output ---------

tests/sqllogictests/suites/query/02_function/02_0069_chain_function.test renamed to tests/sqllogictests/suites/query/02_function/02_0069_experimental_expr.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,18 @@ with t(f) as (select '11|open|22|ai|33|is nothing without sam'.split('|')
2121

2222
statement error 1008
2323
SELECT t.a::String.lowe() FROM numbers(1) t(a)
24+
25+
## List Comprehension
26+
27+
query T
28+
select [ x * 100 FOR x in [1,2,3] if x % 2 = 0 ];
29+
----
30+
[200]
31+
32+
33+
query IT
34+
SELECT 12_000_111_222, [x.split(' ')[2]
35+
FOR x IN ['OpenAI', 'I LOVE', 'YOU ALL']
36+
IF x.INSTR('LOVE') > 0][1];
37+
----
38+
12000111222 LOVE

0 commit comments

Comments
 (0)