Skip to content

Commit babd7ac

Browse files
authored
Add partiql-types and literals typing (#389)
1 parent ff3e950 commit babd7ac

File tree

10 files changed

+476
-3
lines changed

10 files changed

+476
-3
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99
### Changed
10+
- *BREAKING:* partiql-logical-planner: moves `NameResolver` to `partiql-ast-passes`
11+
1012
### Added
1113
- Add ability for partiql-extension-ion extension encoding/decoding of `Value` to/from Ion `Element`
14+
- Add `partiql-types` crate that includes data models for PartiQL Types.
15+
- Add `partiql_ast_passes::static_typer` for type annotating the AST.
16+
1217
### Fixes
1318

1419
## [0.5.0] - 2023-06-06

partiql-ast-passes/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,15 @@ bench = false
2222
[dependencies]
2323
partiql-ast = { path = "../partiql-ast", version = "0.5.*" }
2424
partiql-catalog = { path = "../partiql-catalog", version = "0.5.*" }
25+
partiql-types = { path = "../partiql-types", version = "0.5.*" }
2526

27+
assert_matches = "1.5.*"
2628
fnv = "1"
2729
indexmap = "1.9"
2830
thiserror = "1.0"
2931

3032
[dev-dependencies]
33+
partiql-parser = { path = "../partiql-parser", version = "0.5.*" }
3134

3235
[features]
3336
default = []

partiql-ast-passes/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
77
pub mod error;
88
pub mod name_resolver;
9+
pub mod static_typer;
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
use crate::error::{AstTransformError, AstTransformationError};
2+
use partiql_ast::ast::{
3+
AstNode, AstTypeMap, Bag, Expr, List, Lit, NodeId, Query, QuerySet, Struct,
4+
};
5+
use partiql_ast::visit::{Traverse, Visit, Visitor};
6+
use partiql_catalog::Catalog;
7+
use partiql_types::{ArrayType, BagType, StaticType, StaticTypeKind, StructType};
8+
9+
#[derive(Debug, Clone)]
10+
#[allow(dead_code)]
11+
pub struct AstStaticTyper<'c> {
12+
id_stack: Vec<NodeId>,
13+
container_stack: Vec<Vec<StaticType>>,
14+
errors: Vec<AstTransformError>,
15+
type_map: AstTypeMap<StaticType>,
16+
catalog: &'c dyn Catalog,
17+
}
18+
19+
impl<'c> AstStaticTyper<'c> {
20+
pub fn new(catalog: &'c dyn Catalog) -> Self {
21+
AstStaticTyper {
22+
id_stack: Default::default(),
23+
container_stack: Default::default(),
24+
errors: Default::default(),
25+
type_map: Default::default(),
26+
catalog,
27+
}
28+
}
29+
30+
pub fn type_nodes(
31+
mut self,
32+
query: &AstNode<Query>,
33+
) -> Result<AstTypeMap<StaticType>, AstTransformationError> {
34+
query.visit(&mut self);
35+
if self.errors.is_empty() {
36+
Ok(self.type_map)
37+
} else {
38+
Err(AstTransformationError {
39+
errors: self.errors,
40+
})
41+
}
42+
}
43+
44+
#[inline]
45+
fn current_node(&self) -> &NodeId {
46+
self.id_stack.last().unwrap()
47+
}
48+
}
49+
50+
impl<'c, 'ast> Visitor<'ast> for AstStaticTyper<'c> {
51+
fn enter_ast_node(&mut self, id: NodeId) -> Traverse {
52+
self.id_stack.push(id);
53+
Traverse::Continue
54+
}
55+
56+
fn exit_ast_node(&mut self, id: NodeId) -> Traverse {
57+
assert_eq!(self.id_stack.pop(), Some(id));
58+
Traverse::Continue
59+
}
60+
61+
fn enter_query(&mut self, _query: &'ast Query) -> Traverse {
62+
Traverse::Continue
63+
}
64+
65+
fn exit_query(&mut self, _query: &'ast Query) -> Traverse {
66+
Traverse::Continue
67+
}
68+
69+
fn enter_query_set(&mut self, _query_set: &'ast QuerySet) -> Traverse {
70+
match _query_set {
71+
QuerySet::SetOp(_) => {
72+
todo!()
73+
}
74+
QuerySet::Select(_) => {}
75+
QuerySet::Expr(_) => {}
76+
QuerySet::Values(_) => {
77+
todo!()
78+
}
79+
QuerySet::Table(_) => {
80+
todo!()
81+
}
82+
}
83+
Traverse::Continue
84+
}
85+
86+
fn exit_query_set(&mut self, _query_set: &'ast QuerySet) -> Traverse {
87+
Traverse::Continue
88+
}
89+
90+
fn enter_expr(&mut self, _expr: &'ast Expr) -> Traverse {
91+
Traverse::Continue
92+
}
93+
94+
fn exit_expr(&mut self, _expr: &'ast Expr) -> Traverse {
95+
Traverse::Continue
96+
}
97+
98+
fn enter_lit(&mut self, _lit: &'ast Lit) -> Traverse {
99+
// Currently we're assuming no-schema, hence typing to arbitrary sized scalars.
100+
// TODO type to the corresponding scalar with the introduction of schema
101+
let kind = match _lit {
102+
Lit::Null => StaticTypeKind::Null,
103+
Lit::Missing => StaticTypeKind::Missing,
104+
Lit::Int8Lit(_) => StaticTypeKind::Int,
105+
Lit::Int16Lit(_) => StaticTypeKind::Int,
106+
Lit::Int32Lit(_) => StaticTypeKind::Int,
107+
Lit::Int64Lit(_) => StaticTypeKind::Int,
108+
Lit::DecimalLit(_) => StaticTypeKind::Decimal,
109+
Lit::NumericLit(_) => StaticTypeKind::Decimal,
110+
Lit::RealLit(_) => StaticTypeKind::Float64,
111+
Lit::FloatLit(_) => StaticTypeKind::Float64,
112+
Lit::DoubleLit(_) => StaticTypeKind::Float64,
113+
Lit::BoolLit(_) => StaticTypeKind::Bool,
114+
Lit::IonStringLit(_) => todo!(),
115+
Lit::CharStringLit(_) => StaticTypeKind::String,
116+
Lit::NationalCharStringLit(_) => StaticTypeKind::String,
117+
Lit::BitStringLit(_) => todo!(),
118+
Lit::HexStringLit(_) => todo!(),
119+
Lit::StructLit(_) => StaticTypeKind::Struct(StructType::unconstrained()),
120+
Lit::ListLit(_) => StaticTypeKind::Array(ArrayType::array()),
121+
Lit::BagLit(_) => StaticTypeKind::Bag(BagType::bag()),
122+
Lit::TypedLit(_, _) => todo!(),
123+
};
124+
125+
let ty = StaticType::new(kind);
126+
let id = *self.current_node();
127+
if let Some(c) = self.container_stack.last_mut() {
128+
c.push(ty.clone())
129+
}
130+
self.type_map.insert(id, ty);
131+
Traverse::Continue
132+
}
133+
134+
fn enter_struct(&mut self, _struct: &'ast Struct) -> Traverse {
135+
self.container_stack.push(vec![]);
136+
Traverse::Continue
137+
}
138+
139+
fn exit_struct(&mut self, _struct: &'ast Struct) -> Traverse {
140+
let id = *self.current_node();
141+
let fields = self.container_stack.pop();
142+
143+
// Such type checking will very likely move to a common module
144+
// TODO move to a more appropriate place for re-use.
145+
if let Some(f) = fields {
146+
// We already fail during parsing if the struct has wrong number of key-value pairs, e.g.:
147+
// {'a', 1, 'b'}
148+
// However, adding this check here.
149+
let is_malformed = f.len() % 2 > 0;
150+
if is_malformed {
151+
self.errors.push(AstTransformError::IllegalState(
152+
"Struct key-value pairs are malformed".to_string(),
153+
));
154+
}
155+
156+
let has_invalid_keys = f.chunks(2).map(|t| &t[0]).any(|t| !t.is_string());
157+
if has_invalid_keys || is_malformed {
158+
self.errors.push(AstTransformError::IllegalState(
159+
"Struct keys can only resolve to `String` type".to_string(),
160+
));
161+
}
162+
}
163+
164+
let ty = StaticType::new_struct(StructType::unconstrained());
165+
self.type_map.insert(id, ty.clone());
166+
167+
if let Some(c) = self.container_stack.last_mut() {
168+
c.push(ty)
169+
}
170+
171+
Traverse::Continue
172+
}
173+
174+
fn enter_bag(&mut self, _bag: &'ast Bag) -> Traverse {
175+
self.container_stack.push(vec![]);
176+
Traverse::Continue
177+
}
178+
179+
fn exit_bag(&mut self, _bag: &'ast Bag) -> Traverse {
180+
// TODO add schema validation of BAG elements, e.g. for Schema Bag<Int> if there is at least
181+
// one element that isn't INT there is a type checking error.
182+
183+
// TODO clarify if we need to record the internal types of bag literal or stick w/Schema?
184+
self.container_stack.pop();
185+
186+
let id = *self.current_node();
187+
let ty = StaticType::new_bag(BagType::bag());
188+
189+
self.type_map.insert(id, ty.clone());
190+
if let Some(s) = self.container_stack.last_mut() {
191+
s.push(ty)
192+
}
193+
Traverse::Continue
194+
}
195+
196+
fn enter_list(&mut self, _list: &'ast List) -> Traverse {
197+
self.container_stack.push(vec![]);
198+
Traverse::Continue
199+
}
200+
201+
fn exit_list(&mut self, _list: &'ast List) -> Traverse {
202+
// TODO clarify if we need to record the internal types of array literal or stick w/Schema?
203+
// one element that isn't INT there is a type checking error.
204+
205+
// TODO clarify if we need to record the internal types of array literal or stick w/Schema?
206+
self.container_stack.pop();
207+
208+
let id = *self.current_node();
209+
let ty = StaticType::new_array(ArrayType::array());
210+
211+
self.type_map.insert(id, ty.clone());
212+
if let Some(s) = self.container_stack.last_mut() {
213+
s.push(ty)
214+
}
215+
Traverse::Continue
216+
}
217+
}
218+
219+
#[cfg(test)]
220+
mod tests {
221+
use super::*;
222+
use assert_matches::assert_matches;
223+
use partiql_ast::ast;
224+
use partiql_catalog::PartiqlCatalog;
225+
use partiql_types::{StaticType, StaticTypeKind};
226+
227+
#[test]
228+
fn simple_test() {
229+
assert_matches!(run_literal_test("NULL"), StaticTypeKind::Null);
230+
assert_matches!(run_literal_test("MISSING"), StaticTypeKind::Missing);
231+
assert_matches!(run_literal_test("Missing"), StaticTypeKind::Missing);
232+
assert_matches!(run_literal_test("true"), StaticTypeKind::Bool);
233+
assert_matches!(run_literal_test("false"), StaticTypeKind::Bool);
234+
assert_matches!(run_literal_test("1"), StaticTypeKind::Int);
235+
assert_matches!(run_literal_test("1.5"), StaticTypeKind::Decimal);
236+
assert_matches!(run_literal_test("'hello world!'"), StaticTypeKind::String);
237+
assert_matches!(
238+
run_literal_test("[1, 2 , {'a': 2}]"),
239+
StaticTypeKind::Array(_)
240+
);
241+
assert_matches!(
242+
run_literal_test("<<'1', {'a': 11}>>"),
243+
StaticTypeKind::Bag(_)
244+
);
245+
assert_matches!(
246+
run_literal_test("{'a': 1, 'b': 3, 'c': [1, 2]}"),
247+
StaticTypeKind::Struct(_)
248+
);
249+
}
250+
251+
#[test]
252+
fn simple_err_test() {
253+
assert!(type_statement("{'a': 1, a.b: 3}").is_err());
254+
}
255+
256+
fn run_literal_test(q: &str) -> StaticTypeKind {
257+
let out = type_statement(q).expect("type map");
258+
let values: Vec<&StaticType> = out.values().collect();
259+
values.last().unwrap().kind().clone()
260+
}
261+
262+
fn type_statement(q: &str) -> Result<AstTypeMap<StaticType>, AstTransformationError> {
263+
let parsed = partiql_parser::Parser::default()
264+
.parse(q)
265+
.expect("Expect successful parse");
266+
267+
let catalog = PartiqlCatalog::default();
268+
let typer = AstStaticTyper::new(&catalog);
269+
if let ast::Expr::Query(q) = parsed.ast.as_ref() {
270+
typer.type_nodes(&q)
271+
} else {
272+
panic!("Typing statement other than `Query` are unsupported")
273+
}
274+
}
275+
}

partiql-ast/Cargo.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,12 @@ path = "src/lib.rs"
2020
bench = false
2121

2222
[dependencies]
23+
indexmap = "1.9"
2324
rust_decimal = { version = "1.25.0", default-features = false, features = ["std"] }
24-
2525
serde = { version = "1.*", features = ["derive"], optional = true }
2626

27-
2827
[dev-dependencies]
2928

30-
3129
[features]
3230
default = []
3331
serde = [

partiql-ast/src/ast.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
// As more changes to this AST are expected, unless explicitly advised, using the structures exposed
99
// in this crate directly is not recommended.
1010

11+
use indexmap::IndexMap;
1112
use rust_decimal::Decimal as RustDecimal;
1213

1314
use std::fmt;
@@ -17,6 +18,8 @@ use serde::{Deserialize, Serialize};
1718

1819
use partiql_ast_macros::Visit;
1920

21+
pub type AstTypeMap<T> = IndexMap<NodeId, T>;
22+
2023
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
2124
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2225
pub struct NodeId(pub u32);

partiql-catalog/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ bench = false
2424
partiql-value = { path = "../partiql-value", version = "0.5.*" }
2525
partiql-parser = { path = "../partiql-parser", version = "0.5.*" }
2626
partiql-logical = { path = "../partiql-logical", version = "0.5.*" }
27+
2728
thiserror = "1.0"
2829
ordered-float = "3.*"
2930
itertools = "0.10.*"

partiql-logical-planner/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ partiql-ast = { path = "../partiql-ast", version = "0.5.*" }
2828
partiql-parser = { path = "../partiql-parser", version = "0.5.*" }
2929
partiql-catalog = { path = "../partiql-catalog", version = "0.5.*" }
3030
partiql-ast-passes = { path = "../partiql-ast-passes", version = "0.5.*" }
31+
3132
ion-rs = "0.18"
3233
ordered-float = "3.*"
3334
itertools = "0.10.*"

partiql-types/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ edition.workspace = true
2121
bench = false
2222

2323
[dependencies]
24+
2425
ordered-float = "3.*"
2526
itertools = "0.10.*"
2627
unicase = "2.6"

0 commit comments

Comments
 (0)