Skip to content

Commit d799418

Browse files
Jesse authored and saishreeeee committed
SQLAlchemy 2: Fix failing mypy checks from development (#257)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 3a8ef6d commit d799418

File tree

6 files changed

+63
-21
lines changed

6 files changed

+63
-21
lines changed

src/databricks/sqlalchemy/_parse.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
"""
1313

1414

15+
class DatabricksSqlAlchemyParseException(Exception):
16+
pass
17+
18+
1519
def _match_table_not_found_string(message: str) -> bool:
1620
"""Return True if the message contains a substring indicating that a table was not found"""
1721

@@ -31,7 +35,7 @@ def _describe_table_extended_result_to_dict_list(
3135
"""Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries"""
3236

3337
rows_to_return = []
34-
for row in result:
38+
for row in result.all():
3539
this_row = {"col_name": row.col_name, "data_type": row.data_type}
3640
rows_to_return.append(this_row)
3741

@@ -69,24 +73,33 @@ def extract_three_level_identifier_from_constraint_string(input_str: str) -> dic
6973
"schema": "pysql_dialect_compliance",
7074
"table": "users"
7175
}
76+
77+
Raise a DatabricksSqlAlchemyParseException if a 3L namespace isn't found
7278
"""
7379
pat = re.compile(r"REFERENCES\s+(.*?)\s*\(")
7480
matches = pat.findall(input_str)
7581

7682
if not matches:
77-
return None
83+
raise DatabricksSqlAlchemyParseException(
84+
"3L namespace not found in constraint string"
85+
)
7886

7987
first_match = matches[0]
8088
parts = first_match.split(".")
8189

8290
def strip_backticks(input: str):
8391
return input.replace("`", "")
8492

85-
return {
86-
"catalog": strip_backticks(parts[0]),
87-
"schema": strip_backticks(parts[1]),
88-
"table": strip_backticks(parts[2]),
89-
}
93+
try:
94+
return {
95+
"catalog": strip_backticks(parts[0]),
96+
"schema": strip_backticks(parts[1]),
97+
"table": strip_backticks(parts[2]),
98+
}
99+
except IndexError:
100+
raise DatabricksSqlAlchemyParseException(
101+
"Incomplete 3L namespace found in constraint string: " + ".".join(parts)
102+
)
90103

91104

92105
def _parse_fk_from_constraint_string(constraint_str: str) -> dict:
@@ -170,10 +183,12 @@ def build_fk_dict(
170183
else:
171184
schema_override_dict = {}
172185

186+
# mypy doesn't like this method of conditionally adding a key to a dictionary
187+
# while keeping everything immutable
173188
complete_foreign_key_dict = {
174189
"name": fk_name,
175190
**base_fk_dict,
176-
**schema_override_dict,
191+
**schema_override_dict, # type: ignore
177192
}
178193

179194
return complete_foreign_key_dict
@@ -234,7 +249,7 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis
234249
return output_rows
235250

236251

237-
def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]:
252+
def get_fk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> List[dict]:
238253
"""If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries,
239254
one dictionary per defined constraint
240255
"""
@@ -307,7 +322,11 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColu
307322
"""
308323

309324
pat = re.compile(r"^\w+")
310-
_raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower()
325+
326+
# This method assumes a valid TYPE_NAME field in the response.
327+
# TODO: add error handling in case TGetColumnsResponse format changes
328+
329+
_raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower() # type: ignore
311330
_col_type = GET_COLUMNS_TYPE_MAP[_raw_col_type]
312331

313332
if _raw_col_type == "decimal":
@@ -334,4 +353,5 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColu
334353
"default": thrift_resp_row.COLUMN_DEF,
335354
}
336355

337-
return this_column
356+
# TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
357+
return this_column # type: ignore

src/databricks/sqlalchemy/base.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from typing import Any, List, Optional, Dict, Collection, Iterable, Tuple
2+
from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple
33

44
import databricks.sqlalchemy._ddl as dialect_ddl_impl
55
import databricks.sqlalchemy._types as dialect_type_impl
@@ -73,10 +73,12 @@ class DatabricksDialect(default.DefaultDialect):
7373

7474
# SQLAlchemy requires that a table with no primary key
7575
# constraint return a dictionary that looks like this.
76-
EMPTY_PK = {"constrained_columns": [], "name": None}
76+
EMPTY_PK: Dict[str, Any] = {"constrained_columns": [], "name": None}
7777

7878
# SQLAlchemy requires that a table with no foreign keys
7979
# defined return an empty list. Same for indexes.
80+
EMPTY_FK: List
81+
EMPTY_INDEX: List
8082
EMPTY_FK = EMPTY_INDEX = []
8183

8284
@classmethod
@@ -139,7 +141,7 @@ def _describe_table_extended(
139141
catalog_name: Optional[str] = None,
140142
schema_name: Optional[str] = None,
141143
expect_result=True,
142-
) -> List[Dict[str, str]]:
144+
) -> Union[List[Dict[str, str]], None]:
143145
"""Run DESCRIBE TABLE EXTENDED on a table and return a list of dictionaries of the result.
144146
145147
This method is the fastest way to check for the presence of a table in a schema.
@@ -158,7 +160,7 @@ def _describe_table_extended(
158160
stmt = DDL(f"DESCRIBE TABLE EXTENDED {_target}")
159161

160162
try:
161-
result = connection.execute(stmt).all()
163+
result = connection.execute(stmt)
162164
except DatabaseError as e:
163165
if _match_table_not_found_string(str(e)):
164166
raise sqlalchemy.exc.NoSuchTableError(
@@ -197,9 +199,11 @@ def get_pk_constraint(
197199
schema_name=schema,
198200
)
199201

200-
raw_pk_constraints: List = get_pk_strings_from_dte_output(result)
202+
# Type ignore is because mypy knows that self._describe_table_extended *can*
203+
# return None (even though it never will since expect_result defaults to True)
204+
raw_pk_constraints: List = get_pk_strings_from_dte_output(result) # type: ignore
201205
if not any(raw_pk_constraints):
202-
return self.EMPTY_PK
206+
return self.EMPTY_PK # type: ignore
203207

204208
if len(raw_pk_constraints) > 1:
205209
logger.warning(
@@ -212,11 +216,12 @@ def get_pk_constraint(
212216
pk_name = first_pk_constraint.get("col_name")
213217
pk_constraint_string = first_pk_constraint.get("data_type")
214218

215-
return build_pk_dict(pk_name, pk_constraint_string)
219+
# TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
220+
return build_pk_dict(pk_name, pk_constraint_string) # type: ignore
216221

217222
def get_foreign_keys(
218223
self, connection, table_name, schema=None, **kw
219-
) -> ReflectedForeignKeyConstraint:
224+
) -> List[ReflectedForeignKeyConstraint]:
220225
"""Return information about foreign_keys in `table_name`."""
221226

222227
result = self._describe_table_extended(
@@ -225,7 +230,9 @@ def get_foreign_keys(
225230
schema_name=schema,
226231
)
227232

228-
raw_fk_constraints: List = get_fk_strings_from_dte_output(result)
233+
# Type ignore is because mypy knows that self._describe_table_extended *can*
234+
# return None (even though it never will since expect_result defaults to True)
235+
raw_fk_constraints: List = get_fk_strings_from_dte_output(result) # type: ignore
229236

230237
if not any(raw_fk_constraints):
231238
return self.EMPTY_FK
@@ -239,7 +246,8 @@ def get_foreign_keys(
239246
)
240247
fk_constraints.append(this_constraint_dict)
241248

242-
return fk_constraints
249+
# TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
250+
return fk_constraints # type: ignore
243251

244252
def get_indexes(self, connection, table_name, schema=None, **kw):
245253
"""SQLAlchemy requires this method. Databricks doesn't support indexes."""

src/databricks/sqlalchemy/test/_future.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# type: ignore
2+
13
from enum import Enum
24

35
import pytest

src/databricks/sqlalchemy/test/_regression.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# type: ignore
2+
13
import pytest
24
from sqlalchemy.testing.suite import (
35
ArgSignatureTest,

src/databricks/sqlalchemy/test/_unsupported.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# type: ignore
2+
13
from enum import Enum
24

35
import pytest

src/databricks/sqlalchemy/test_local/test_parsing.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
build_fk_dict,
77
build_pk_dict,
88
match_dte_rows_by_value,
9+
DatabricksSqlAlchemyParseException,
910
)
1011

1112

@@ -55,6 +56,13 @@ def test_extract_3l_namespace_from_constraint_string():
5556
), "Failed to extract 3L namespace from constraint string"
5657

5758

59+
def test_extract_3l_namespace_from_bad_constraint_string():
60+
input = "FOREIGN KEY (`parent_user_id`) REFERENCES `pysql_dialect_compliance`.`users` (`user_id`)"
61+
62+
with pytest.raises(DatabricksSqlAlchemyParseException):
63+
extract_three_level_identifier_from_constraint_string(input)
64+
65+
5866
@pytest.mark.parametrize("schema", [None, "some_schema"])
5967
def test_build_fk_dict(schema):
6068
fk_constraint_string = "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`some_schema`.`users` (`user_id`)"

0 commit comments

Comments
 (0)