Skip to content

Commit e7b3d00

Browse files
authored
feat(sqlsmith): Support generate subquery and with clause (#12956)
1 parent 9ca40ec commit e7b3d00

File tree

4 files changed

+212
-24
lines changed

4 files changed

+212
-24
lines changed

src/query/functions/src/scalars/vector.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,23 @@ pub fn register(registry: &mut FunctionRegistry) {
9090
return;
9191
}
9292
}
93-
let data = std::str::from_utf8(data).unwrap();
9493

94+
let data = match std::str::from_utf8(data) {
95+
Ok(data) => data,
96+
Err(_) => {
97+
ctx.set_error(
98+
output.len(),
99+
format!("Invalid data: {:?}", String::from_utf8_lossy(data)),
100+
);
101+
output.push(vec![F32::from(0.0)].into());
102+
return;
103+
}
104+
};
105+
if ctx.func_ctx.openai_api_key.is_empty() {
106+
ctx.set_error(output.len(), "openai_api_key is empty".to_string());
107+
output.push(vec![F32::from(0.0)].into());
108+
return;
109+
}
95110
let api_base = ctx.func_ctx.openai_api_embedding_base_url.clone();
96111
let api_key = ctx.func_ctx.openai_api_key.clone();
97112
let api_version = ctx.func_ctx.openai_api_version.clone();
@@ -140,7 +155,24 @@ pub fn register(registry: &mut FunctionRegistry) {
140155
}
141156
}
142157

143-
let data = std::str::from_utf8(data).unwrap();
158+
let data = match std::str::from_utf8(data) {
159+
Ok(data) => data,
160+
Err(_) => {
161+
ctx.set_error(
162+
output.len(),
163+
format!("Invalid data: {:?}", String::from_utf8_lossy(data)),
164+
);
165+
output.put_str("");
166+
output.commit_row();
167+
return;
168+
}
169+
};
170+
if ctx.func_ctx.openai_api_key.is_empty() {
171+
ctx.set_error(output.len(), "openai_api_key is empty".to_string());
172+
output.put_str("");
173+
output.commit_row();
174+
return;
175+
}
144176
let api_base = ctx.func_ctx.openai_api_chat_base_url.clone();
145177
let api_key = ctx.func_ctx.openai_api_key.clone();
146178
let api_version = ctx.func_ctx.openai_api_version.clone();

src/tests/sqlsmith/src/sql_gen/expr.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
4545
}
4646
}
4747

48+
/// Generates a "simple" expression of type `ty`: either a column reference
/// or a scalar value, with no nested operators or function calls built here.
///
/// With probability 0.6 the expression comes from `gen_column(ty)`;
/// otherwise it comes from `gen_scalar_value(ty)`.
pub(crate) fn gen_simple_expr(&mut self, ty: &DataType) -> Expr {
49+
if self.rng.gen_bool(0.6) {
50+
self.gen_column(ty)
51+
} else {
52+
self.gen_scalar_value(ty)
53+
}
54+
}
55+
4856
fn gen_column(&mut self, ty: &DataType) -> Expr {
4957
for bound_column in &self.bound_columns {
5058
if bound_column.data_type == *ty {
@@ -389,7 +397,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
389397
}
390398
7 => {
391399
let not = self.rng.gen_bool(0.5);
392-
let subquery = self.gen_subquery();
400+
let (subquery, _) = self.gen_subquery(false);
393401
Expr::Exists {
394402
span: None,
395403
not,
@@ -404,7 +412,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
404412
3 => Some(SubqueryModifier::Some),
405413
_ => unreachable!(),
406414
};
407-
let subquery = self.gen_subquery();
415+
let (subquery, _) = self.gen_subquery(true);
408416
Expr::Subquery {
409417
span: None,
410418
modifier,
@@ -415,7 +423,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
415423
let expr_ty = self.gen_simple_data_type();
416424
let expr = self.gen_expr(&expr_ty);
417425
let not = self.rng.gen_bool(0.5);
418-
let subquery = self.gen_subquery();
426+
let (subquery, _) = self.gen_subquery(true);
419427
Expr::InSubquery {
420428
span: None,
421429
expr: Box::new(expr),

src/tests/sqlsmith/src/sql_gen/query.rs

Lines changed: 162 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,18 @@ use common_ast::ast::Query;
2626
use common_ast::ast::SelectStmt;
2727
use common_ast::ast::SelectTarget;
2828
use common_ast::ast::SetExpr;
29+
use common_ast::ast::TableAlias;
2930
use common_ast::ast::TableReference;
31+
use common_ast::ast::With;
32+
use common_ast::ast::CTE;
33+
use common_expression::infer_schema_type;
3034
use common_expression::types::DataType;
3135
use common_expression::types::NumberDataType;
3236
use common_expression::TableDataType;
3337
use common_expression::TableField;
38+
use common_expression::TableSchemaRef;
3439
use common_expression::TableSchemaRefExt;
40+
use rand::distributions::Alphanumeric;
3541
use rand::Rng;
3642

3743
use crate::sql_gen::Column;
@@ -40,19 +46,20 @@ use crate::sql_gen::Table;
4046

4147
impl<'a, R: Rng> SqlGenerator<'a, R> {
4248
pub(crate) fn gen_query(&mut self) -> Query {
43-
self.bound_columns.clear();
49+
self.cte_tables.clear();
4450
self.bound_tables.clear();
51+
self.bound_columns.clear();
4552
self.is_join = false;
4653

54+
let with = self.gen_with();
4755
let body = self.gen_set_expr();
4856
let limit = self.gen_limit();
4957
let offset = self.gen_offset(limit.len());
5058
let order_by = self.gen_order_by(self.group_by.clone());
5159

5260
Query {
5361
span: None,
54-
// TODO
55-
with: None,
62+
with,
5663
body,
5764
order_by,
5865
limit,
@@ -61,7 +68,10 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
6168
}
6269
}
6370

64-
pub(crate) fn gen_subquery(&mut self) -> Query {
71+
// Scalar, IN / NOT IN, and ANY / SOME / ALL subqueries must return exactly one column.
72+
// EXISTS / NOT EXISTS subqueries may return any number of columns.
73+
pub(crate) fn gen_subquery(&mut self, one_column: bool) -> (Query, TableSchemaRef) {
74+
let current_cte_tables = mem::take(&mut self.cte_tables);
6575
let current_bound_tables = mem::take(&mut self.bound_tables);
6676
let current_bound_columns = mem::take(&mut self.bound_columns);
6777
let current_is_join = self.is_join;
@@ -70,13 +80,101 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
7080
self.bound_columns = vec![];
7181
self.is_join = false;
7282

73-
let query = self.gen_query();
83+
// Only generate simple subquery
84+
// TODO: complex subquery
85+
let from = self.gen_from();
86+
87+
let len = if one_column {
88+
1
89+
} else {
90+
self.rng.gen_range(1..=5)
91+
};
92+
93+
let name: String = (0..3)
94+
.map(|_| self.rng.sample(Alphanumeric) as char)
95+
.collect();
96+
let mut fields = Vec::with_capacity(len);
97+
let mut select_list = Vec::with_capacity(len);
98+
for i in 0..len {
99+
let ty = self.gen_simple_data_type();
100+
let expr = self.gen_simple_expr(&ty);
101+
let col_name = format!("c{}{}", name, i);
102+
let table_type = infer_schema_type(&ty).unwrap();
103+
let field = TableField::new(&col_name, table_type);
104+
fields.push(field);
105+
let alias = Identifier::from_name(col_name);
106+
let target = SelectTarget::AliasedExpr {
107+
expr: Box::new(expr),
108+
alias: Some(alias),
109+
};
110+
select_list.push(target);
111+
}
112+
let schema = TableSchemaRefExt::create(fields);
74113

114+
let select = SelectStmt {
115+
span: None,
116+
hints: None,
117+
distinct: false,
118+
select_list,
119+
from,
120+
selection: None,
121+
group_by: None,
122+
having: None,
123+
window_list: None,
124+
};
125+
let body = SetExpr::Select(Box::new(select));
126+
127+
let query = Query {
128+
span: None,
129+
with: None,
130+
body,
131+
order_by: vec![],
132+
limit: vec![],
133+
offset: None,
134+
ignore_result: false,
135+
};
136+
137+
self.cte_tables = current_cte_tables;
75138
self.bound_tables = current_bound_tables;
76139
self.bound_columns = current_bound_columns;
77140
self.is_join = current_is_join;
78141

79-
query
142+
(query, schema)
143+
}
144+
145+
/// Optionally generates a non-recursive `WITH` clause.
///
/// Returns `None` with probability 0.8. Otherwise builds between one and
/// three CTEs by calling `gen_cte` and wraps them in a `With` whose
/// `recursive` flag is always `false`.
fn gen_with(&mut self) -> Option<With> {
146+
// Most generated queries carry no WITH clause at all.
if self.rng.gen_bool(0.8) {
147+
return None;
148+
}
149+
150+
let len = self.rng.gen_range(1..=3);
151+
let mut ctes = Vec::with_capacity(len);
152+
for _ in 0..len {
153+
let cte = self.gen_cte();
154+
ctes.push(cte);
155+
}
156+
157+
Some(With {
158+
span: None,
159+
recursive: false,
160+
ctes,
161+
})
162+
}
163+
164+
/// Generates a single common table expression.
///
/// The CTE body is a subquery from `gen_subquery(false)`, i.e. not
/// restricted to a single output column. A `Table` describing the subquery's
/// schema is pushed onto `self.cte_tables`, and the `materialized` flag is
/// chosen with a 50/50 coin flip.
fn gen_cte(&mut self) -> CTE {
165+
let (subquery, schema) = self.gen_subquery(false);
166+
167+
let (table, alias) = self.gen_subquery_table(schema);
168+
// Register the CTE so `gen_table_ref` can later pick it as a table source.
self.cte_tables.push(table);
169+
170+
let materialized = self.rng.gen_bool(0.5);
171+
172+
CTE {
173+
span: None,
174+
alias,
175+
materialized,
176+
query: Box::new(subquery),
177+
}
80178
}
81179

82180
fn gen_set_expr(&mut self) -> SetExpr {
@@ -304,17 +402,21 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
304402
// TODO: generate more table reference
305403
// let table_ref_num = self.rng.gen_range(1..=3);
306404
match self.rng.gen_range(0..=10) {
307-
0..=7 => {
308-
let i = self.rng.gen_range(0..self.tables.len());
309-
let table_ref = self.gen_table_ref(self.tables[i].clone());
405+
0..=6 => {
406+
let (table_ref, _) = self.gen_table_ref();
310407
table_refs.push(table_ref);
311408
}
312409
// join
313-
8..=9 => {
410+
7..=8 => {
314411
self.is_join = true;
315412
let join = self.gen_join_table_ref();
316413
table_refs.push(join);
317414
}
415+
// subquery
416+
9 => {
417+
let subquery = self.gen_subquery_table_ref();
418+
table_refs.push(subquery);
419+
}
318420
10 => {
319421
let table_func = self.gen_table_func();
320422
table_refs.push(table_func);
@@ -325,12 +427,21 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
325427
table_refs
326428
}
327429

328-
fn gen_table_ref(&mut self, table: Table) -> TableReference {
430+
fn gen_table_ref(&mut self) -> (TableReference, TableSchemaRef) {
431+
let len = self.tables.len() + self.cte_tables.len();
432+
let i = self.rng.gen_range(0..len);
433+
434+
let table = if i < self.tables.len() {
435+
self.tables[i].clone()
436+
} else {
437+
self.cte_tables[len - i - 1].clone()
438+
};
439+
let schema = table.schema.clone();
329440
let table_name = Identifier::from_name(table.name.clone());
330441

331442
self.bound_table(table);
332443

333-
TableReference::Table {
444+
let table_ref = TableReference::Table {
334445
span: None,
335446
// TODO
336447
catalog: None,
@@ -345,7 +456,8 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
345456
pivot: None,
346457
// TODO
347458
unpivot: None,
348-
}
459+
};
460+
(table_ref, schema)
349461
}
350462

351463
// Only test:
@@ -453,11 +565,10 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
453565
_ => unreachable!(),
454566
}
455567
}
568+
456569
fn gen_join_table_ref(&mut self) -> TableReference {
457-
let i = self.rng.gen_range(0..self.tables.len());
458-
let j = if i == self.tables.len() - 1 { 0 } else { i + 1 };
459-
let left_table = self.gen_table_ref(self.tables[i].clone());
460-
let right_table = self.gen_table_ref(self.tables[j].clone());
570+
let (left_table, left_schema) = self.gen_table_ref();
571+
let (right_table, right_schema) = self.gen_table_ref();
461572

462573
let op = match self.rng.gen_range(0..=8) {
463574
0 => JoinOperator::Inner,
@@ -479,8 +590,8 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
479590
JoinCondition::On(Box::new(expr))
480591
}
481592
1 => {
482-
let left_fields = self.tables[i].schema.fields();
483-
let right_fields = self.tables[j].schema.fields();
593+
let left_fields = left_schema.fields();
594+
let right_fields = right_schema.fields();
484595

485596
let mut names = Vec::new();
486597
for left_field in left_fields {
@@ -534,6 +645,19 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
534645
TableReference::Join { span: None, join }
535646
}
536647

648+
/// Generates a derived-table reference: an aliased subquery used as a
/// table source.
///
/// The subquery comes from `gen_subquery(false)`, i.e. it is not restricted
/// to a single output column. A `Table` built from its schema is passed to
/// `bound_table`, and the returned `TableReference::Subquery` always carries
/// the generated alias.
fn gen_subquery_table_ref(&mut self) -> TableReference {
649+
let (subquery, schema) = self.gen_subquery(false);
650+
651+
let (table, alias) = self.gen_subquery_table(schema);
652+
// `gen_subquery` has already restored the caller's bound state, so this
// binds the derived table in the enclosing query's scope.
self.bound_table(table);
653+
654+
TableReference::Subquery {
655+
span: None,
656+
subquery: Box::new(subquery),
657+
alias: Some(alias),
658+
}
659+
}
660+
537661
fn gen_selection(&mut self) -> Option<Expr> {
538662
match self.rng.gen_range(0..=9) {
539663
0..=5 => Some(self.gen_expr(&DataType::Boolean)),
@@ -545,6 +669,25 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
545669
}
546670
}
547671

672+
/// Builds the `Table` descriptor and matching `TableAlias` for a subquery
/// with the given output `schema`.
///
/// The table name is `t` followed by four random alphanumeric characters.
/// The alias lists one column identifier per schema field, in schema order
/// and using the field's own name. Returns `(table, alias)`; both share the
/// same generated table name.
fn gen_subquery_table(&mut self, schema: TableSchemaRef) -> (Table, TableAlias) {
673+
let name: String = (0..4)
674+
.map(|_| self.rng.sample(Alphanumeric) as char)
675+
.collect();
676+
let table_name = format!("t{}", name);
677+
let mut columns = Vec::with_capacity(schema.num_fields());
678+
for field in schema.fields() {
679+
let column = Identifier::from_name(field.name.clone());
680+
columns.push(column);
681+
}
682+
let alias = TableAlias {
683+
name: Identifier::from_name(table_name.clone()),
684+
columns,
685+
};
686+
let table = Table::new(table_name, schema);
687+
688+
(table, alias)
689+
}
690+
548691
fn bound_table(&mut self, table: Table) {
549692
for (i, field) in table.schema.fields().iter().enumerate() {
550693
let column = Column {

0 commit comments

Comments
 (0)