Skip to content

Commit fc4eea2

Browse files
committed
[cdd/shared/ast_utils.py] Fix Literal support in get_types ; remove unnecessary type restriction in infer_imports and fix its type_comment implementation ; [cdd/tests/test_shared/test_ast_utils.py] Increase test coverage of this file to 100% ; [cdd/__init__.py] Bump version
1 parent 04b7ad4 commit fc4eea2

File tree

3 files changed

+146
-19
lines changed

3 files changed

+146
-19
lines changed

cdd/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from logging import getLogger as get_logger
1010

1111
__author__ = "Samuel Marks" # type: str
12-
__version__ = "0.0.99rc44" # type: str
12+
__version__ = "0.0.99rc45" # type: str
1313
__description__ = (
1414
"Open API to/fro routes, models, and tests. "
1515
"Convert between docstrings, classes, methods, argparse, pydantic, and SQLalchemy."

cdd/shared/ast_utils.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
from json import dumps
4747
from operator import attrgetter, contains, inv, itemgetter, neg, not_, pos
4848
from os import path
49-
from typing import FrozenSet, Generator, Optional
49+
from typing import Callable, FrozenSet, Generator, MutableSet, Optional
50+
from typing import Tuple as TTuple
5051
from typing import __all__ as typing__all__
5152

5253
import cdd.shared.source_transformer
@@ -2215,7 +2216,14 @@ def get_types(node):
22152216
return iter((node.value.id, node.slice.id))
22162217
elif isinstance(node.slice, Tuple):
22172218
return chain.from_iterable(
2218-
((node.value.id,), map(get_value, map(get_value, node.slice.elts)))
2219+
(
2220+
(node.value.id,),
2221+
(
2222+
iter(())
2223+
if node.value.id == "Literal"
2224+
else map(get_value, map(get_value, node.slice.elts))
2225+
),
2226+
)
22192227
)
22202228

22212229

@@ -2228,16 +2236,16 @@ def infer_imports(module, modules_to_all=DEFAULT_MODULES_TO_ALL):
22282236
- sqlalchemy
22292237
- pydantic
22302238
2231-
:param module: Module, ClassDef, FunctionDef, AsyncFunctionDef, Assign
2232-
:type module: ```Union[ClassDef, FunctionDef, AsyncFunctionDef, Assign]```
2239+
:param module: Module, ClassDef, FunctionDef, AsyncFunctionDef, Assign, AnnAssign
2240+
:type module: ```Union[Module, ClassDef, FunctionDef, AsyncFunctionDef, Assign, AnnAssign]```
22332241
22342242
:param modules_to_all: Tuple of module_name to __all__ of module; (str) to FrozenSet[str]
22352243
:type modules_to_all: ```tuple[tuple[str, frozenset], ...]```
22362244
22372245
:return: List of imports
2238-
:rtype: ```Optional[Tuple[Union[Import, ImportFrom]]]```
2246+
:rtype: ```Optional[Tuple[Union[Import, ImportFrom], ...]]```
22392247
"""
2240-
if isinstance(module, (ClassDef, FunctionDef, AsyncFunctionDef, Assign)):
2248+
if isinstance(module, (ClassDef, FunctionDef, AsyncFunctionDef, Assign, AnnAssign)):
22412249
module: Module = Module(body=[module], type_ignores=[], stmt=None)
22422250
assert isinstance(module, Module), "Expected `Module` got `{type_name}`".format(
22432251
type_name=type(module).__name__
@@ -2252,7 +2260,13 @@ def node_to_importable_name(node):
22522260
:rtype: ```Optional[str]```
22532261
"""
22542262
if getattr(node, "type_comment", None) is not None:
2255-
return node.type_comment
2263+
return (
2264+
node.type_comment
2265+
if node.type_comment in simple_types
2266+
else get_value(
2267+
get_value(get_value(ast.parse(node.type_comment).body[0]))
2268+
)
2269+
)
22562270
elif getattr(node, "annotation", None) is not None:
22572271
node = node # type: Union[AnnAssign, arg]
22582272
return node.annotation # cast(node, Union[AnnAssign, arg])
@@ -2261,7 +2275,9 @@ def node_to_importable_name(node):
22612275
else:
22622276
return None
22632277

2264-
_symbol_to_import = partial(symbol_to_import, modules_to_all=modules_to_all)
2278+
_symbol_to_import: Callable[[str], Optional[TTuple[str, str]]] = partial(
2279+
symbol_to_import, modules_to_all=modules_to_all
2280+
)
22652281

22662282
# Lots of room for optimisation here; but its probably NP-hard:
22672283
imports = tuple(
@@ -2352,8 +2368,10 @@ def deduplicate_sorted_imports(module):
23522368
:return: Module but with duplicate import entries in first import block removed
23532369
:rtype: ```Module```
23542370
"""
2355-
assert isinstance(module, Module)
2356-
fst_import_idx = next(
2371+
assert isinstance(module, Module), "Expected `Module` got `{}`".format(
2372+
type(module).__name__
2373+
)
2374+
fst_import_idx: Optional[int] = next(
23572375
map(
23582376
itemgetter(0),
23592377
filter(
@@ -2365,7 +2383,7 @@ def deduplicate_sorted_imports(module):
23652383
)
23662384
if fst_import_idx is None:
23672385
return module
2368-
lst_import_idx = next(
2386+
lst_import_idx: Optional[int] = next(
23692387
iter(
23702388
deque(
23712389
map(
@@ -2380,7 +2398,7 @@ def deduplicate_sorted_imports(module):
23802398
),
23812399
None,
23822400
)
2383-
name_seen = set()
2401+
name_seen: MutableSet[str] = set()
23842402

23852403
module.body = (
23862404
module.body[:fst_import_idx]

cdd/tests/test_shared/test_ast_utils.py

Lines changed: 115 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
arguments,
3232
keyword,
3333
)
34+
from collections import deque
3435
from copy import deepcopy
3536
from itertools import repeat
3637
from os import extsep, path
@@ -88,6 +89,7 @@
8889
function_adder_ast,
8990
function_adder_str,
9091
)
92+
from cdd.tests.mocks.pydantic import pydantic_class_cls_def
9193
from cdd.tests.mocks.sqlalchemy import config_decl_base_ast
9294
from cdd.tests.utils_for_tests import inspectable_compile, run_ast_test, unittest_main
9395

@@ -432,7 +434,7 @@ def test_infer_imports_with_sqlalchemy(self) -> None:
432434
"""
433435
imports = infer_imports(
434436
config_decl_base_ast
435-
) # type: Optional[Tuple[Union[Import, ImportFrom]]]
437+
) # type: Optional[TTuple[Union[Import, ImportFrom], ...]]
436438
self.assertIsNotNone(imports)
437439
self.assertEqual(len(imports), 1)
438440
run_ast_test(
@@ -455,6 +457,57 @@ def test_infer_imports_with_sqlalchemy(self) -> None:
455457
),
456458
)
457459

460+
def test_infer_imports_with_simple_node_variants(self) -> None:
461+
"""
462+
Test that `infer_imports` with some simple variants
463+
"""
464+
465+
def inner_test(imports):
466+
"""
467+
Run the actual test
468+
469+
:param imports: The imports to compare against
470+
:type imports: ```TList[ImportFrom]```
471+
"""
472+
self.assertIsNotNone(imports)
473+
self.assertEqual(len(imports), 1)
474+
run_ast_test(
475+
self,
476+
imports[0],
477+
ImportFrom(
478+
module="typing" if PY_GTE_3_8 else "typing_extensions",
479+
names=[
480+
alias(
481+
"Literal",
482+
None,
483+
identifier=None,
484+
identifier_name=None,
485+
)
486+
],
487+
level=0,
488+
),
489+
)
490+
491+
deque(
492+
map(
493+
inner_test,
494+
map(
495+
infer_imports,
496+
(
497+
pydantic_class_cls_def,
498+
Assign(
499+
targets=[Name("a", Load(), lineno=None, col_offset=None)],
500+
value=set_value("cat"),
501+
type_comment="Literal['cat']",
502+
expr=None,
503+
lineno=None,
504+
),
505+
),
506+
),
507+
),
508+
maxlen=0,
509+
)
510+
458511
def test_node_to_dict(self) -> None:
459512
"""
460513
Tests `node_to_dict`
@@ -642,6 +695,7 @@ def test_get_value(self) -> None:
642695
)
643696
self.assertIsNone(get_value(Name(None, None)))
644697
self.assertEqual(get_value(get_value(ast.parse("-5").body[0])), -5)
698+
self.assertEqual(get_value(Num(n=-5, constant_value=None, string=None)), -5)
645699

646700
def test_set_value(self) -> None:
647701
"""Tests that `set_value` returns the right type for the right Python version"""
@@ -749,21 +803,76 @@ def test_find_ast_type_fails(self) -> None:
749803

750804
def test_get_types(self) -> None:
751805
"""Test that `get_types` functions correctly"""
806+
self.assertTupleEqual(tuple(get_types(None)), tuple())
807+
self.assertTupleEqual(tuple(get_types("str")), ("str",))
808+
self.assertTupleEqual(
809+
tuple(
810+
get_types(
811+
Subscript(
812+
value=Name(
813+
id="Optional", ctx=Load(), lineno=None, col_offset=None
814+
),
815+
slice=Name(id="Any", ctx=Load(), lineno=None, col_offset=None),
816+
ctx=Load(),
817+
expr_context_ctx=None,
818+
expr_slice=None,
819+
expr_value=None,
820+
lineno=None,
821+
col_offset=None,
822+
)
823+
)
824+
),
825+
("Optional", "Any"),
826+
)
752827
self.assertTupleEqual(
753-
tuple(get_types("str")),
754-
("str",),
828+
tuple(
829+
get_types(
830+
Subscript(
831+
value=Name(
832+
id="Literal", ctx=Load(), lineno=None, col_offset=None
833+
),
834+
slice=Tuple(
835+
elts=list(map(set_value, ("foo", "bar"))),
836+
ctx=Load(),
837+
expr=None,
838+
lineno=None,
839+
col_offset=None,
840+
),
841+
ctx=Load(),
842+
expr_context_ctx=None,
843+
expr_slice=None,
844+
expr_value=None,
845+
lineno=None,
846+
col_offset=None,
847+
)
848+
)
849+
),
850+
("Literal",),
755851
)
756852
self.assertTupleEqual(
757853
tuple(
758854
get_types(
759855
Subscript(
760-
value=Name(id="Optional", ctx=Load()),
761-
slice=Name(id="Any", ctx=Load()),
856+
value=Name(
857+
id="Tuple", ctx=Load(), lineno=None, col_offset=None
858+
),
859+
slice=Tuple(
860+
elts=list(map(set_value, ("int", "float"))),
861+
ctx=Load(),
862+
expr=None,
863+
lineno=None,
864+
col_offset=None,
865+
),
762866
ctx=Load(),
867+
expr_context_ctx=None,
868+
expr_slice=None,
869+
expr_value=None,
870+
lineno=None,
871+
col_offset=None,
763872
)
764873
)
765874
),
766-
("Optional", "Any"),
875+
("Tuple", "int", "float"),
767876
)
768877

769878
def test_to_named_class_def(self) -> None:

0 commit comments

Comments
 (0)