Skip to content

Commit 4d04163

Browse files
committed
Added unit tests for new utils.get_types function
1 parent f8673cb commit 4d04163

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

cmd2/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1258,7 +1258,10 @@ def get_types(func_or_method: Callable[..., Any]) -> tuple[dict[str, Any], Any]:
12581258
:return tuple with first element being dictionary mapping param names to type hints
12591259
and second element being return type hint, unspecified, returns None
12601260
"""
1261-
type_hints = get_type_hints(func_or_method) # Get dictionary of type hints
1261+
try:
1262+
type_hints = get_type_hints(func_or_method) # Get dictionary of type hints
1263+
except TypeError as exc:
1264+
raise ValueError("Argument passed to get_types should be a function or method") from exc
12621265
ret_ann = type_hints.pop('return', None) # Pop off the return annotation if it exists
12631266
if inspect.ismethod(func_or_method):
12641267
type_hints.pop('self', None) # Pop off `self` hint for methods

tests/test_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,3 +892,44 @@ def custom_similarity_function(s1, s2) -> float:
892892

893893
suggested_command = cu.suggest_similar("test", ["test"], similarity_function_to_use=custom_similarity_function)
894894
assert suggested_command is None
895+
896+
897+
def test_get_types_invalid_input() -> None:
898+
x = 1
899+
with pytest.raises(ValueError, match="Argument passed to get_types should be a function or method"):
900+
cu.get_types(x)
901+
902+
903+
def test_get_types_empty() -> None:
904+
def a(b):
905+
print(b)
906+
907+
param_ann, ret_ann = cu.get_types(a)
908+
assert ret_ann is None
909+
assert param_ann == {}
910+
911+
912+
def test_get_types_non_empty() -> None:
913+
def foo(x: int) -> str:
914+
return f"{x * x}"
915+
916+
param_ann, ret_ann = cu.get_types(foo)
917+
assert ret_ann is str
918+
param_name, param_value = next(iter(param_ann.items()))
919+
assert param_name == 'x'
920+
assert param_value is int
921+
922+
923+
def test_get_types_method() -> None:
924+
class Foo:
925+
def bar(self, x: bool) -> None:
926+
print(x)
927+
928+
f = Foo()
929+
930+
param_ann, ret_ann = cu.get_types(f.bar)
931+
assert ret_ann is None
932+
assert len(param_ann) == 1
933+
param_name, param_value = next(iter(param_ann.items()))
934+
assert param_name == 'x'
935+
assert param_value is bool

0 commit comments

Comments
 (0)