|
| 1 | +use crate::error::{AstTransformError, AstTransformationError}; |
| 2 | +use partiql_ast::ast::{ |
| 3 | + AstNode, AstTypeMap, Bag, Expr, List, Lit, NodeId, Query, QuerySet, Struct, |
| 4 | +}; |
| 5 | +use partiql_ast::visit::{Traverse, Visit, Visitor}; |
| 6 | +use partiql_catalog::Catalog; |
| 7 | +use partiql_types::{ArrayType, BagType, StaticType, StaticTypeKind, StructType}; |
| 8 | + |
| 9 | +#[derive(Debug, Clone)] |
| 10 | +#[allow(dead_code)] |
| 11 | +pub struct AstStaticTyper<'c> { |
| 12 | + id_stack: Vec<NodeId>, |
| 13 | + container_stack: Vec<Vec<StaticType>>, |
| 14 | + errors: Vec<AstTransformError>, |
| 15 | + type_map: AstTypeMap<StaticType>, |
| 16 | + catalog: &'c dyn Catalog, |
| 17 | +} |
| 18 | + |
| 19 | +impl<'c> AstStaticTyper<'c> { |
| 20 | + pub fn new(catalog: &'c dyn Catalog) -> Self { |
| 21 | + AstStaticTyper { |
| 22 | + id_stack: Default::default(), |
| 23 | + container_stack: Default::default(), |
| 24 | + errors: Default::default(), |
| 25 | + type_map: Default::default(), |
| 26 | + catalog, |
| 27 | + } |
| 28 | + } |
| 29 | + |
| 30 | + pub fn type_nodes( |
| 31 | + mut self, |
| 32 | + query: &AstNode<Query>, |
| 33 | + ) -> Result<AstTypeMap<StaticType>, AstTransformationError> { |
| 34 | + query.visit(&mut self); |
| 35 | + if self.errors.is_empty() { |
| 36 | + Ok(self.type_map) |
| 37 | + } else { |
| 38 | + Err(AstTransformationError { |
| 39 | + errors: self.errors, |
| 40 | + }) |
| 41 | + } |
| 42 | + } |
| 43 | + |
| 44 | + #[inline] |
| 45 | + fn current_node(&self) -> &NodeId { |
| 46 | + self.id_stack.last().unwrap() |
| 47 | + } |
| 48 | +} |
| 49 | + |
| 50 | +impl<'c, 'ast> Visitor<'ast> for AstStaticTyper<'c> { |
| 51 | + fn enter_ast_node(&mut self, id: NodeId) -> Traverse { |
| 52 | + self.id_stack.push(id); |
| 53 | + Traverse::Continue |
| 54 | + } |
| 55 | + |
| 56 | + fn exit_ast_node(&mut self, id: NodeId) -> Traverse { |
| 57 | + assert_eq!(self.id_stack.pop(), Some(id)); |
| 58 | + Traverse::Continue |
| 59 | + } |
| 60 | + |
| 61 | + fn enter_query(&mut self, _query: &'ast Query) -> Traverse { |
| 62 | + Traverse::Continue |
| 63 | + } |
| 64 | + |
| 65 | + fn exit_query(&mut self, _query: &'ast Query) -> Traverse { |
| 66 | + Traverse::Continue |
| 67 | + } |
| 68 | + |
| 69 | + fn enter_query_set(&mut self, _query_set: &'ast QuerySet) -> Traverse { |
| 70 | + match _query_set { |
| 71 | + QuerySet::SetOp(_) => { |
| 72 | + todo!() |
| 73 | + } |
| 74 | + QuerySet::Select(_) => {} |
| 75 | + QuerySet::Expr(_) => {} |
| 76 | + QuerySet::Values(_) => { |
| 77 | + todo!() |
| 78 | + } |
| 79 | + QuerySet::Table(_) => { |
| 80 | + todo!() |
| 81 | + } |
| 82 | + } |
| 83 | + Traverse::Continue |
| 84 | + } |
| 85 | + |
| 86 | + fn exit_query_set(&mut self, _query_set: &'ast QuerySet) -> Traverse { |
| 87 | + Traverse::Continue |
| 88 | + } |
| 89 | + |
| 90 | + fn enter_expr(&mut self, _expr: &'ast Expr) -> Traverse { |
| 91 | + Traverse::Continue |
| 92 | + } |
| 93 | + |
| 94 | + fn exit_expr(&mut self, _expr: &'ast Expr) -> Traverse { |
| 95 | + Traverse::Continue |
| 96 | + } |
| 97 | + |
| 98 | + fn enter_lit(&mut self, _lit: &'ast Lit) -> Traverse { |
| 99 | + // Currently we're assuming no-schema, hence typing to arbitrary sized scalars. |
| 100 | + // TODO type to the corresponding scalar with the introduction of schema |
| 101 | + let kind = match _lit { |
| 102 | + Lit::Null => StaticTypeKind::Null, |
| 103 | + Lit::Missing => StaticTypeKind::Missing, |
| 104 | + Lit::Int8Lit(_) => StaticTypeKind::Int, |
| 105 | + Lit::Int16Lit(_) => StaticTypeKind::Int, |
| 106 | + Lit::Int32Lit(_) => StaticTypeKind::Int, |
| 107 | + Lit::Int64Lit(_) => StaticTypeKind::Int, |
| 108 | + Lit::DecimalLit(_) => StaticTypeKind::Decimal, |
| 109 | + Lit::NumericLit(_) => StaticTypeKind::Decimal, |
| 110 | + Lit::RealLit(_) => StaticTypeKind::Float64, |
| 111 | + Lit::FloatLit(_) => StaticTypeKind::Float64, |
| 112 | + Lit::DoubleLit(_) => StaticTypeKind::Float64, |
| 113 | + Lit::BoolLit(_) => StaticTypeKind::Bool, |
| 114 | + Lit::IonStringLit(_) => todo!(), |
| 115 | + Lit::CharStringLit(_) => StaticTypeKind::String, |
| 116 | + Lit::NationalCharStringLit(_) => StaticTypeKind::String, |
| 117 | + Lit::BitStringLit(_) => todo!(), |
| 118 | + Lit::HexStringLit(_) => todo!(), |
| 119 | + Lit::StructLit(_) => StaticTypeKind::Struct(StructType::unconstrained()), |
| 120 | + Lit::ListLit(_) => StaticTypeKind::Array(ArrayType::array()), |
| 121 | + Lit::BagLit(_) => StaticTypeKind::Bag(BagType::bag()), |
| 122 | + Lit::TypedLit(_, _) => todo!(), |
| 123 | + }; |
| 124 | + |
| 125 | + let ty = StaticType::new(kind); |
| 126 | + let id = *self.current_node(); |
| 127 | + if let Some(c) = self.container_stack.last_mut() { |
| 128 | + c.push(ty.clone()) |
| 129 | + } |
| 130 | + self.type_map.insert(id, ty); |
| 131 | + Traverse::Continue |
| 132 | + } |
| 133 | + |
| 134 | + fn enter_struct(&mut self, _struct: &'ast Struct) -> Traverse { |
| 135 | + self.container_stack.push(vec![]); |
| 136 | + Traverse::Continue |
| 137 | + } |
| 138 | + |
| 139 | + fn exit_struct(&mut self, _struct: &'ast Struct) -> Traverse { |
| 140 | + let id = *self.current_node(); |
| 141 | + let fields = self.container_stack.pop(); |
| 142 | + |
| 143 | + // Such type checking will very likely move to a common module |
| 144 | + // TODO move to a more appropriate place for re-use. |
| 145 | + if let Some(f) = fields { |
| 146 | + // We already fail during parsing if the struct has wrong number of key-value pairs, e.g.: |
| 147 | + // {'a', 1, 'b'} |
| 148 | + // However, adding this check here. |
| 149 | + let is_malformed = f.len() % 2 > 0; |
| 150 | + if is_malformed { |
| 151 | + self.errors.push(AstTransformError::IllegalState( |
| 152 | + "Struct key-value pairs are malformed".to_string(), |
| 153 | + )); |
| 154 | + } |
| 155 | + |
| 156 | + let has_invalid_keys = f.chunks(2).map(|t| &t[0]).any(|t| !t.is_string()); |
| 157 | + if has_invalid_keys || is_malformed { |
| 158 | + self.errors.push(AstTransformError::IllegalState( |
| 159 | + "Struct keys can only resolve to `String` type".to_string(), |
| 160 | + )); |
| 161 | + } |
| 162 | + } |
| 163 | + |
| 164 | + let ty = StaticType::new_struct(StructType::unconstrained()); |
| 165 | + self.type_map.insert(id, ty.clone()); |
| 166 | + |
| 167 | + if let Some(c) = self.container_stack.last_mut() { |
| 168 | + c.push(ty) |
| 169 | + } |
| 170 | + |
| 171 | + Traverse::Continue |
| 172 | + } |
| 173 | + |
| 174 | + fn enter_bag(&mut self, _bag: &'ast Bag) -> Traverse { |
| 175 | + self.container_stack.push(vec![]); |
| 176 | + Traverse::Continue |
| 177 | + } |
| 178 | + |
| 179 | + fn exit_bag(&mut self, _bag: &'ast Bag) -> Traverse { |
| 180 | + // TODO add schema validation of BAG elements, e.g. for Schema Bag<Int> if there is at least |
| 181 | + // one element that isn't INT there is a type checking error. |
| 182 | + |
| 183 | + // TODO clarify if we need to record the internal types of bag literal or stick w/Schema? |
| 184 | + self.container_stack.pop(); |
| 185 | + |
| 186 | + let id = *self.current_node(); |
| 187 | + let ty = StaticType::new_bag(BagType::bag()); |
| 188 | + |
| 189 | + self.type_map.insert(id, ty.clone()); |
| 190 | + if let Some(s) = self.container_stack.last_mut() { |
| 191 | + s.push(ty) |
| 192 | + } |
| 193 | + Traverse::Continue |
| 194 | + } |
| 195 | + |
| 196 | + fn enter_list(&mut self, _list: &'ast List) -> Traverse { |
| 197 | + self.container_stack.push(vec![]); |
| 198 | + Traverse::Continue |
| 199 | + } |
| 200 | + |
| 201 | + fn exit_list(&mut self, _list: &'ast List) -> Traverse { |
| 202 | + // TODO clarify if we need to record the internal types of array literal or stick w/Schema? |
| 203 | + // one element that isn't INT there is a type checking error. |
| 204 | + |
| 205 | + // TODO clarify if we need to record the internal types of array literal or stick w/Schema? |
| 206 | + self.container_stack.pop(); |
| 207 | + |
| 208 | + let id = *self.current_node(); |
| 209 | + let ty = StaticType::new_array(ArrayType::array()); |
| 210 | + |
| 211 | + self.type_map.insert(id, ty.clone()); |
| 212 | + if let Some(s) = self.container_stack.last_mut() { |
| 213 | + s.push(ty) |
| 214 | + } |
| 215 | + Traverse::Continue |
| 216 | + } |
| 217 | +} |
| 218 | + |
| 219 | +#[cfg(test)] |
| 220 | +mod tests { |
| 221 | + use super::*; |
| 222 | + use assert_matches::assert_matches; |
| 223 | + use partiql_ast::ast; |
| 224 | + use partiql_catalog::PartiqlCatalog; |
| 225 | + use partiql_types::{StaticType, StaticTypeKind}; |
| 226 | + |
| 227 | + #[test] |
| 228 | + fn simple_test() { |
| 229 | + assert_matches!(run_literal_test("NULL"), StaticTypeKind::Null); |
| 230 | + assert_matches!(run_literal_test("MISSING"), StaticTypeKind::Missing); |
| 231 | + assert_matches!(run_literal_test("Missing"), StaticTypeKind::Missing); |
| 232 | + assert_matches!(run_literal_test("true"), StaticTypeKind::Bool); |
| 233 | + assert_matches!(run_literal_test("false"), StaticTypeKind::Bool); |
| 234 | + assert_matches!(run_literal_test("1"), StaticTypeKind::Int); |
| 235 | + assert_matches!(run_literal_test("1.5"), StaticTypeKind::Decimal); |
| 236 | + assert_matches!(run_literal_test("'hello world!'"), StaticTypeKind::String); |
| 237 | + assert_matches!( |
| 238 | + run_literal_test("[1, 2 , {'a': 2}]"), |
| 239 | + StaticTypeKind::Array(_) |
| 240 | + ); |
| 241 | + assert_matches!( |
| 242 | + run_literal_test("<<'1', {'a': 11}>>"), |
| 243 | + StaticTypeKind::Bag(_) |
| 244 | + ); |
| 245 | + assert_matches!( |
| 246 | + run_literal_test("{'a': 1, 'b': 3, 'c': [1, 2]}"), |
| 247 | + StaticTypeKind::Struct(_) |
| 248 | + ); |
| 249 | + } |
| 250 | + |
| 251 | + #[test] |
| 252 | + fn simple_err_test() { |
| 253 | + assert!(type_statement("{'a': 1, a.b: 3}").is_err()); |
| 254 | + } |
| 255 | + |
| 256 | + fn run_literal_test(q: &str) -> StaticTypeKind { |
| 257 | + let out = type_statement(q).expect("type map"); |
| 258 | + let values: Vec<&StaticType> = out.values().collect(); |
| 259 | + values.last().unwrap().kind().clone() |
| 260 | + } |
| 261 | + |
| 262 | + fn type_statement(q: &str) -> Result<AstTypeMap<StaticType>, AstTransformationError> { |
| 263 | + let parsed = partiql_parser::Parser::default() |
| 264 | + .parse(q) |
| 265 | + .expect("Expect successful parse"); |
| 266 | + |
| 267 | + let catalog = PartiqlCatalog::default(); |
| 268 | + let typer = AstStaticTyper::new(&catalog); |
| 269 | + if let ast::Expr::Query(q) = parsed.ast.as_ref() { |
| 270 | + typer.type_nodes(&q) |
| 271 | + } else { |
| 272 | + panic!("Typing statement other than `Query` are unsupported") |
| 273 | + } |
| 274 | + } |
| 275 | +} |
0 commit comments