Skip to content

Commit 7a8b308

Browse files
committed
Beginning to support 310 union types. See Issue 64 for more detail.
1 parent 8cd23fb commit 7a8b308

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

tap/tap.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from shlex import quote, split
99
import sys
1010
import time
11-
from types import MethodType
11+
from types import MethodType, UnionType
1212
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints
1313
from typing_inspect import is_literal_type, get_args
1414
from warnings import warn
@@ -34,7 +34,7 @@
3434
# Constants
3535
EMPTY_TYPE = get_args(List)[0] if len(get_args(List)) > 0 else tuple()
3636
BOXED_COLLECTION_TYPES = {List, list, Set, set, Tuple, tuple}
37-
OPTIONAL_TYPES = {Optional, Union}
37+
OPTIONAL_TYPES = {Optional, Union, UnionType}
3838
BOXED_TYPES = BOXED_COLLECTION_TYPES | OPTIONAL_TYPES
3939

4040

@@ -170,12 +170,12 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
170170
# If type is not explicitly provided, set it if it's one of our supported default types
171171
if 'type' not in kwargs:
172172

173-
# Unbox Optional[type] and set var_type = type
173+
# Unbox Union[type] (Optional[type]) and set var_type = type
174174
if get_origin(var_type) in OPTIONAL_TYPES:
175175
var_args = get_args(var_type)
176176

177177
if len(var_args) > 0:
178-
var_type = get_args(var_type)[0]
178+
var_type = var_args[0]
179179

180180
# If var_type is tuple as in Python 3.6, change to a typing type
181181
# (e.g., (typing.List, <class 'bool'>) ==> typing.List[bool])
@@ -187,12 +187,14 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
187187
# First check whether it is a literal type or a boxed literal type
188188
if is_literal_type(var_type):
189189
var_type, kwargs['choices'] = get_literals(var_type, variable)
190+
190191
elif (get_origin(var_type) in (List, list, Set, set)
191192
and len(get_args(var_type)) > 0
192193
and is_literal_type(get_args(var_type)[0])):
193194
var_type, kwargs['choices'] = get_literals(get_args(var_type)[0], variable)
194195
if kwargs.get('action') not in {'append', 'append_const'}:
195196
kwargs['nargs'] = kwargs.get('nargs', '*')
197+
196198
# Handle Tuple type (with type args) by extracting types of Tuple elements and enforcing them
197199
elif get_origin(var_type) in (Tuple, tuple) and len(get_args(var_type)) > 0:
198200
loop = False

tap/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import subprocess
1313
import sys
1414
import tokenize
15+
from types import UnionType
1516
from typing import (
1617
Any,
1718
Callable,
@@ -471,7 +472,7 @@ def enforce_reproducibility(saved_reproducibility_data: Optional[Dict[str, str]]
471472
f'in current args.')
472473

473474

474-
# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.8 and 3.9
475+
# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.8, 3.9, and 3.10
475476
# https://github.com/ilevkivskyi/typing_inspect/issues/64
476477
# https://github.com/ilevkivskyi/typing_inspect/issues/65
477478
def get_origin(tp: Any) -> Any:
@@ -481,4 +482,7 @@ def get_origin(tp: Any) -> Any:
481482
if origin is None:
482483
origin = tp
483484

485+
if isinstance(origin, UnionType):
486+
origin = UnionType
487+
484488
return origin

0 commit comments

Comments
 (0)