1
1
from argparse import ArgumentParser , ArgumentTypeError
2
+ import ast
2
3
from base64 import b64encode , b64decode
3
4
import copy
4
5
from functools import wraps
10
11
import re
11
12
import subprocess
12
13
import sys
14
+ import textwrap
13
15
import tokenize
14
16
from typing import (
15
17
Any ,
16
18
Callable ,
17
19
Dict ,
18
20
Generator ,
21
+ Iterable ,
19
22
Iterator ,
20
23
List ,
21
24
Literal ,
22
25
Optional ,
26
+ Set ,
23
27
Tuple ,
24
28
Union ,
25
29
)
26
30
from typing_inspect import get_args as typing_inspect_get_args , get_origin as typing_inspect_get_origin
31
+ import warnings
27
32
28
33
if sys .version_info >= (3 , 10 ):
29
34
from types import UnionType
@@ -162,7 +167,7 @@ def get_argument_name(*name_or_flags) -> str:
162
167
return "help"
163
168
164
169
if len (name_or_flags ) > 1 :
165
- name_or_flags = [ n_or_f for n_or_f in name_or_flags if n_or_f .startswith ("--" )]
170
+ name_or_flags = tuple ( n_or_f for n_or_f in name_or_flags if n_or_f .startswith ("--" ))
166
171
167
172
if len (name_or_flags ) != 1 :
168
173
raise ValueError (f"There should only be a single canonical name for argument { name_or_flags } !" )
@@ -201,30 +206,28 @@ def is_positional_arg(*name_or_flags) -> bool:
201
206
return not is_option_arg (* name_or_flags )
202
207
203
208
204
- def tokenize_source (obj : object ) -> Generator :
205
- """Returns a generator for the tokens of the object's source code."""
206
- source = inspect .getsource (obj )
207
- token_generator = tokenize .generate_tokens (StringIO (source ).readline )
209
+ def tokenize_source (source : str ) -> Generator [tokenize .TokenInfo , None , None ]:
210
+ """Returns a generator for the tokens of the object's source code, given the source code."""
211
+ return tokenize .generate_tokens (StringIO (source ).readline )
208
212
209
- return token_generator
210
213
211
-
212
- def get_class_column (obj : type ) -> int :
213
- """Determines the column number for class variables in a class."""
214
+ def get_class_column (tokens : Iterable [tokenize .TokenInfo ]) -> int :
215
+ """Determines the column number for class variables in a class, given the tokens of the class."""
214
216
first_line = 1
215
- for token_type , token , (start_line , start_column ), (end_line , end_column ), line in tokenize_source ( obj ) :
217
+ for token_type , token , (start_line , start_column ), (end_line , end_column ), line in tokens :
216
218
if token .strip () == "@" :
217
219
first_line += 1
218
220
if start_line <= first_line or token .strip () == "" :
219
221
continue
220
222
221
223
return start_column
224
+ raise ValueError ("Could not find any class variables in the class." )
222
225
223
226
224
- def source_line_to_tokens (obj : object ) -> Dict [int , List [Dict [str , Union [str , int ]]]]:
225
- """Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code ."""
227
+ def source_line_to_tokens (tokens : Iterable [ tokenize . TokenInfo ] ) -> Dict [int , List [Dict [str , Union [str , int ]]]]:
228
+ """Extract a map from each line number to list of mappings providing information about each token ."""
226
229
line_to_tokens = {}
227
- for token_type , token , (start_line , start_column ), (end_line , end_column ), line in tokenize_source ( obj ) :
230
+ for token_type , token , (start_line , start_column ), (end_line , end_column ), line in tokens :
228
231
line_to_tokens .setdefault (start_line , []).append (
229
232
{
230
233
"token_type" : token_type ,
@@ -240,20 +243,98 @@ def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, in
240
243
return line_to_tokens
241
244
242
245
246
+ def get_subsequent_assign_lines (source_cls : str ) -> Tuple [Set [int ], Set [int ]]:
247
+ """For all multiline assign statements, get the line numbers after the first line in the assignment.
248
+
249
+ :param source_cls: The source code of the class.
250
+ :return: A set of intermediate line numbers for multiline assign statements and a set of final line numbers.
251
+ """
252
+ # Parse source code using ast (with an if statement to avoid indentation errors)
253
+ source = f"if True:\n { textwrap .indent (source_cls , ' ' )} "
254
+ body = ast .parse (source ).body [0 ]
255
+
256
+ # Set up warning message
257
+ parse_warning = (
258
+ "Could not parse class source code to extract comments. Comments in the help string may be incorrect."
259
+ )
260
+
261
+ # Check for correct parsing
262
+ if not isinstance (body , ast .If ):
263
+ warnings .warn (parse_warning )
264
+ return set (), set ()
265
+
266
+ # Extract if body
267
+ if_body = body .body
268
+
269
+ # Check for a single body
270
+ if len (if_body ) != 1 :
271
+ warnings .warn (parse_warning )
272
+ return set (), set ()
273
+
274
+ # Extract class body
275
+ cls_body = if_body [0 ]
276
+
277
+ # Check for a single class definition
278
+ if not isinstance (cls_body , ast .ClassDef ):
279
+ warnings .warn (parse_warning )
280
+ return set (), set ()
281
+
282
+ # Get line numbers of assign statements
283
+ intermediate_assign_lines = set ()
284
+ final_assign_lines = set ()
285
+ for node in cls_body .body :
286
+ if isinstance (node , (ast .Assign , ast .AnnAssign )):
287
+ # Check if the end line number is found
288
+ if node .end_lineno is None :
289
+ warnings .warn (parse_warning )
290
+ continue
291
+
292
+ # Only consider multiline assign statements
293
+ if node .end_lineno > node .lineno :
294
+ # Get intermediate line number of assign statement excluding the first line (and minus 1 for the if statement)
295
+ intermediate_assign_lines |= set (range (node .lineno , node .end_lineno - 1 ))
296
+
297
+ # If multiline assign statement, get the line number of the last line (and minus 1 for the if statement)
298
+ final_assign_lines .add (node .end_lineno - 1 )
299
+
300
+ return intermediate_assign_lines , final_assign_lines
301
+
302
+
243
303
def get_class_variables (cls : type ) -> Dict [str , Dict [str , str ]]:
244
304
"""Returns a dictionary mapping class variables to their additional information (currently just comments)."""
305
+ # Get the source code and tokens of the class
306
+ source_cls = inspect .getsource (cls )
307
+ tokens = tuple (tokenize_source (source_cls ))
308
+
245
309
# Get mapping from line number to tokens
246
- line_to_tokens = source_line_to_tokens (cls )
310
+ line_to_tokens = source_line_to_tokens (tokens )
247
311
248
312
# Get class variable column number
249
- class_variable_column = get_class_column (cls )
313
+ class_variable_column = get_class_column (tokens )
314
+
315
+ # For all multiline assign statements, get the line numbers after the first line of the assignment
316
+ # This is used to avoid identifying comments in multiline assign statements
317
+ intermediate_assign_lines , final_assign_lines = get_subsequent_assign_lines (source_cls )
250
318
251
319
# Extract class variables
252
320
class_variable = None
253
321
variable_to_comment = {}
254
- for tokens in line_to_tokens .values ():
255
- for i , token in enumerate (tokens ):
322
+ for line , tokens in line_to_tokens .items ():
323
+ # If this is the final line of a multiline assign, extract any potential comments
324
+ if line in final_assign_lines :
325
+ # Find the comment (if it exists)
326
+ for token in tokens :
327
+ if token ["token_type" ] == tokenize .COMMENT :
328
+ # Leave out "#" and whitespace from comment
329
+ variable_to_comment [class_variable ]["comment" ] = token ["token" ][1 :].strip ()
330
+ break
331
+ continue
332
+
333
+ # Skip assign lines after the first line of multiline assign statements
334
+ if line in intermediate_assign_lines :
335
+ continue
256
336
337
+ for i , token in enumerate (tokens ):
257
338
# Skip whitespace
258
339
if token ["token" ].strip () == "" :
259
340
continue
@@ -265,8 +346,21 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
265
346
and token ["token" ][:1 ] in {'"' , "'" }
266
347
):
267
348
sep = " " if variable_to_comment [class_variable ]["comment" ] else ""
349
+
350
+ # Identify the quote character (single or double)
268
351
quote_char = token ["token" ][:1 ]
269
- variable_to_comment [class_variable ]["comment" ] += sep + token ["token" ].strip (quote_char ).strip ()
352
+
353
+ # Identify the number of quote characters at the start of the string
354
+ num_quote_chars = len (token ["token" ]) - len (token ["token" ].lstrip (quote_char ))
355
+
356
+ # Remove the number of quote characters at the start of the string and the end of the string
357
+ token ["token" ] = token ["token" ][num_quote_chars :- num_quote_chars ]
358
+
359
+ # Remove the unicode escape sequences (e.g. "\"")
360
+ token ["token" ] = bytes (token ["token" ], encoding = "ascii" ).decode ("unicode-escape" )
361
+
362
+ # Add the token to the comment, stripping whitespace
363
+ variable_to_comment [class_variable ]["comment" ] += sep + token ["token" ].strip ()
270
364
271
365
# Match class variable
272
366
class_variable = None
@@ -292,7 +386,7 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
292
386
return variable_to_comment
293
387
294
388
295
- def get_literals (literal : Literal , variable : str ) -> Tuple [Callable [[str ], Any ], List [str ]]:
389
+ def get_literals (literal : Literal , variable : str ) -> Tuple [Callable [[str ], Any ], List [type ]]:
296
390
"""Extracts the values from a Literal type and ensures that the values are all primitive types."""
297
391
literals = list (get_args (literal ))
298
392
@@ -449,33 +543,6 @@ def as_python_object(dct: Any) -> Any:
449
543
return dct
450
544
451
545
452
- def fix_py36_copy (func : Callable ) -> Callable :
453
- """Decorator that fixes functions using Python 3.6 deepcopy of ArgumentParsers.
454
-
455
- Based on https://stackoverflow.com/questions/6279305/typeerror-cannot-deepcopy-this-pattern-object
456
- """
457
- if sys .version_info [:2 ] > (3 , 6 ):
458
- return func
459
-
460
- @wraps (func )
461
- def wrapper (* args , ** kwargs ):
462
- re_type = type (re .compile ("" ))
463
- has_prev_val = re_type in copy ._deepcopy_dispatch
464
- prev_val = copy ._deepcopy_dispatch .get (re_type , None )
465
- copy ._deepcopy_dispatch [type (re .compile ("" ))] = lambda r , _ : r
466
-
467
- result = func (* args , ** kwargs )
468
-
469
- if has_prev_val :
470
- copy ._deepcopy_dispatch [re_type ] = prev_val
471
- else :
472
- del copy ._deepcopy_dispatch [re_type ]
473
-
474
- return result
475
-
476
- return wrapper
477
-
478
-
479
546
def enforce_reproducibility (
480
547
saved_reproducibility_data : Optional [Dict [str , str ]], current_reproducibility_data : Dict [str , str ], path : PathLike
481
548
) -> None :
@@ -512,7 +579,7 @@ def enforce_reproducibility(
512
579
raise ValueError (f"{ no_reproducibility_message } : Uncommitted changes " f"in current args." )
513
580
514
581
515
- # TODO: remove this once typing_inspect.get_origin is fixed for Python 3.8, 3.9, and 3.10
582
+ # TODO: remove this once typing_inspect.get_origin is fixed for Python 3.9 and 3.10
516
583
# https://github.com/ilevkivskyi/typing_inspect/issues/64
517
584
# https://github.com/ilevkivskyi/typing_inspect/issues/65
518
585
def get_origin (tp : Any ) -> Any :
0 commit comments