8
8
from shlex import quote , split
9
9
import sys
10
10
import time
11
- from types import MethodType
11
+ from types import MethodType , UnionType
12
12
from typing import Any , Callable , Dict , List , Optional , Sequence , Set , Tuple , TypeVar , Union , get_type_hints
13
13
from typing_inspect import is_literal_type , get_args
14
14
from warnings import warn
34
34
# Constants
35
35
EMPTY_TYPE = get_args (List )[0 ] if len (get_args (List )) > 0 else tuple ()
36
36
BOXED_COLLECTION_TYPES = {List , list , Set , set , Tuple , tuple }
37
- OPTIONAL_TYPES = {Optional , Union }
37
+ OPTIONAL_TYPES = {Optional , Union , UnionType }
38
38
BOXED_TYPES = BOXED_COLLECTION_TYPES | OPTIONAL_TYPES
39
39
40
40
@@ -170,12 +170,12 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
170
170
# If type is not explicitly provided, set it if it's one of our supported default types
171
171
if 'type' not in kwargs :
172
172
173
- # Unbox Optional[type] and set var_type = type
173
+ # Unbox Union[type] ( Optional[type]) and set var_type = type
174
174
if get_origin (var_type ) in OPTIONAL_TYPES :
175
175
var_args = get_args (var_type )
176
176
177
177
if len (var_args ) > 0 :
178
- var_type = get_args ( var_type ) [0 ]
178
+ var_type = var_args [0 ]
179
179
180
180
# If var_type is tuple as in Python 3.6, change to a typing type
181
181
# (e.g., (typing.List, <class 'bool'>) ==> typing.List[bool])
@@ -187,12 +187,14 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
187
187
# First check whether it is a literal type or a boxed literal type
188
188
if is_literal_type (var_type ):
189
189
var_type , kwargs ['choices' ] = get_literals (var_type , variable )
190
+
190
191
elif (get_origin (var_type ) in (List , list , Set , set )
191
192
and len (get_args (var_type )) > 0
192
193
and is_literal_type (get_args (var_type )[0 ])):
193
194
var_type , kwargs ['choices' ] = get_literals (get_args (var_type )[0 ], variable )
194
195
if kwargs .get ('action' ) not in {'append' , 'append_const' }:
195
196
kwargs ['nargs' ] = kwargs .get ('nargs' , '*' )
197
+
196
198
# Handle Tuple type (with type args) by extracting types of Tuple elements and enforcing them
197
199
elif get_origin (var_type ) in (Tuple , tuple ) and len (get_args (var_type )) > 0 :
198
200
loop = False
0 commit comments