Skip to content

Commit 792da8d

Browse files
committed
Merge branch 'main' into git-no-remote
2 parents 7ccd355 + fe0d2d8 commit 792da8d

12 files changed

+198
-96
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
matrix:
1818
os: [ubuntu-latest, macos-latest, windows-latest]
19-
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
19+
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
2020

2121
steps:
2222
- uses: actions/checkout@main

LICENSE.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Copyright (c) 2022 Jesse Michel and Kyle Swanson
1+
Copyright (c) 2024 Jesse Michel and Kyle Swanson
22

33
Permission is hereby granted, free of charge, to any person obtaining a copy
44
of this software and associated documentation files (the "Software"), to deal

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Running `python square.py --num 2` will print `The square of your number is 4.0.
3939

4040
## Installation
4141

42-
Tap requires Python 3.8+
42+
Tap requires Python 3.9+
4343

4444
To install Tap from PyPI run:
4545

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ dependencies = [
2121
"packaging",
2222
"typing-inspect >= 0.7.1",
2323
]
24-
requires-python = ">=3.8"
24+
requires-python = ">=3.9"
2525
classifiers = [
2626
"Programming Language :: Python :: 3",
27-
"Programming Language :: Python :: 3.8",
2827
"Programming Language :: Python :: 3.9",
2928
"Programming Language :: Python :: 3.10",
3029
"Programming Language :: Python :: 3.11",
3130
"Programming Language :: Python :: 3.12",
31+
"Programming Language :: Python :: 3.13",
3232
"License :: OSI Approved :: MIT License",
3333
"Operating System :: OS Independent",
3434
"Typing :: Typed",

src/tap/tap.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
TupleTypeEnforcer,
2727
define_python_object_encoder,
2828
as_python_object,
29-
fix_py36_copy,
3029
enforce_reproducibility,
3130
PathLike,
3231
)
@@ -227,7 +226,7 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
227226
# Handle Tuple type (with type args) by extracting types of Tuple elements and enforcing them
228227
elif get_origin(var_type) in (Tuple, tuple) and len(get_args(var_type)) > 0:
229228
loop = False
230-
types = get_args(var_type)
229+
types = list(get_args(var_type))
231230

232231
# Handle Tuple[type, ...]
233232
if len(types) == 2 and types[1] == Ellipsis:
@@ -504,7 +503,7 @@ def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> Union
504503
while len(super_classes) > 0:
505504
super_class = super_classes.pop(0)
506505

507-
if super_class not in visited and issubclass(super_class, Tap):
506+
if super_class not in visited and issubclass(super_class, Tap) and super_class is not Tap:
508507
super_dictionary = extract_func(super_class)
509508

510509
# Update only unseen variables to avoid overriding subclass values
@@ -527,13 +526,7 @@ def _get_class_dict(self) -> Dict[str, Any]:
527526
class_dict = {
528527
var: val
529528
for var, val in class_dict.items()
530-
if not (
531-
var.startswith("_")
532-
or callable(val)
533-
or isinstance(val, staticmethod)
534-
or isinstance(val, classmethod)
535-
or isinstance(val, property)
536-
)
529+
if not (var.startswith("_") or callable(val) or isinstance(val, (staticmethod, classmethod, property)))
537530
}
538531

539532
return class_dict
@@ -547,9 +540,7 @@ def _get_class_variables(self) -> dict:
547540
class_variable_names = {**self._get_annotations(), **self._get_class_dict()}.keys()
548541

549542
try:
550-
class_variables = self._get_from_self_and_super(
551-
extract_func=lambda super_class: get_class_variables(super_class)
552-
)
543+
class_variables = self._get_from_self_and_super(extract_func=get_class_variables)
553544

554545
# Handle edge-case of source code modification while code is running
555546
variables_to_add = class_variable_names - class_variables.keys()
@@ -717,7 +708,6 @@ def __str__(self) -> str:
717708
"""
718709
return pformat(self.as_dict())
719710

720-
@fix_py36_copy
721711
def __deepcopy__(self, memo: Dict[int, Any] = None) -> TapType:
722712
"""Deepcopy the Tap object."""
723713
copied = type(self).__new__(type(self))

src/tap/utils.py

Lines changed: 114 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from argparse import ArgumentParser, ArgumentTypeError
2+
import ast
23
from base64 import b64encode, b64decode
34
import copy
45
from functools import wraps
@@ -10,20 +11,24 @@
1011
import re
1112
import subprocess
1213
import sys
14+
import textwrap
1315
import tokenize
1416
from typing import (
1517
Any,
1618
Callable,
1719
Dict,
1820
Generator,
21+
Iterable,
1922
Iterator,
2023
List,
2124
Literal,
2225
Optional,
26+
Set,
2327
Tuple,
2428
Union,
2529
)
2630
from typing_inspect import get_args as typing_inspect_get_args, get_origin as typing_inspect_get_origin
31+
import warnings
2732

2833
if sys.version_info >= (3, 10):
2934
from types import UnionType
@@ -162,7 +167,7 @@ def get_argument_name(*name_or_flags) -> str:
162167
return "help"
163168

164169
if len(name_or_flags) > 1:
165-
name_or_flags = [n_or_f for n_or_f in name_or_flags if n_or_f.startswith("--")]
170+
name_or_flags = tuple(n_or_f for n_or_f in name_or_flags if n_or_f.startswith("--"))
166171

167172
if len(name_or_flags) != 1:
168173
raise ValueError(f"There should only be a single canonical name for argument {name_or_flags}!")
@@ -201,30 +206,28 @@ def is_positional_arg(*name_or_flags) -> bool:
201206
return not is_option_arg(*name_or_flags)
202207

203208

204-
def tokenize_source(obj: object) -> Generator:
205-
"""Returns a generator for the tokens of the object's source code."""
206-
source = inspect.getsource(obj)
207-
token_generator = tokenize.generate_tokens(StringIO(source).readline)
209+
def tokenize_source(source: str) -> Generator[tokenize.TokenInfo, None, None]:
210+
"""Returns a generator for the tokens of the given source code."""
211+
return tokenize.generate_tokens(StringIO(source).readline)
208212

209-
return token_generator
210213

211-
212-
def get_class_column(obj: type) -> int:
213-
"""Determines the column number for class variables in a class."""
214+
def get_class_column(tokens: Iterable[tokenize.TokenInfo]) -> int:
215+
"""Determines the column number for class variables in a class, given the tokens of the class."""
214216
first_line = 1
215-
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
217+
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
216218
if token.strip() == "@":
217219
first_line += 1
218220
if start_line <= first_line or token.strip() == "":
219221
continue
220222

221223
return start_column
224+
raise ValueError("Could not find any class variables in the class.")
222225

223226

224-
def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, int]]]]:
225-
"""Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code."""
227+
def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> Dict[int, List[Dict[str, Union[str, int]]]]:
228+
"""Extracts a map from each line number to a list of mappings, one per token starting on that line, providing information about that token."""
226229
line_to_tokens = {}
227-
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
230+
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
228231
line_to_tokens.setdefault(start_line, []).append(
229232
{
230233
"token_type": token_type,
@@ -240,20 +243,98 @@ def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, in
240243
return line_to_tokens
241244

242245

246+
def get_subsequent_assign_lines(source_cls: str) -> Tuple[Set[int], Set[int]]:
247+
"""For all multiline assign statements, get the line numbers after the first line in the assignment.
248+
249+
:param source_cls: The source code of the class.
250+
:return: A set of intermediate line numbers for multiline assign statements and a set of final line numbers.
251+
"""
252+
# Parse source code using ast (with an if statement to avoid indentation errors)
253+
source = f"if True:\n{textwrap.indent(source_cls, ' ')}"
254+
body = ast.parse(source).body[0]
255+
256+
# Set up warning message
257+
parse_warning = (
258+
"Could not parse class source code to extract comments. Comments in the help string may be incorrect."
259+
)
260+
261+
# Check for correct parsing
262+
if not isinstance(body, ast.If):
263+
warnings.warn(parse_warning)
264+
return set(), set()
265+
266+
# Extract if body
267+
if_body = body.body
268+
269+
# Check for a single body
270+
if len(if_body) != 1:
271+
warnings.warn(parse_warning)
272+
return set(), set()
273+
274+
# Extract class body
275+
cls_body = if_body[0]
276+
277+
# Check for a single class definition
278+
if not isinstance(cls_body, ast.ClassDef):
279+
warnings.warn(parse_warning)
280+
return set(), set()
281+
282+
# Get line numbers of assign statements
283+
intermediate_assign_lines = set()
284+
final_assign_lines = set()
285+
for node in cls_body.body:
286+
if isinstance(node, (ast.Assign, ast.AnnAssign)):
287+
# Check if the end line number is found
288+
if node.end_lineno is None:
289+
warnings.warn(parse_warning)
290+
continue
291+
292+
# Only consider multiline assign statements
293+
if node.end_lineno > node.lineno:
294+
# Get the intermediate line numbers of the assign statement, excluding its first and last lines (minus 1 to account for the added if statement)
295+
intermediate_assign_lines |= set(range(node.lineno, node.end_lineno - 1))
296+
297+
# If multiline assign statement, get the line number of the last line (and minus 1 for the if statement)
298+
final_assign_lines.add(node.end_lineno - 1)
299+
300+
return intermediate_assign_lines, final_assign_lines
301+
302+
243303
def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
244304
"""Returns a dictionary mapping class variables to their additional information (currently just comments)."""
305+
# Get the source code and tokens of the class
306+
source_cls = inspect.getsource(cls)
307+
tokens = tuple(tokenize_source(source_cls))
308+
245309
# Get mapping from line number to tokens
246-
line_to_tokens = source_line_to_tokens(cls)
310+
line_to_tokens = source_line_to_tokens(tokens)
247311

248312
# Get class variable column number
249-
class_variable_column = get_class_column(cls)
313+
class_variable_column = get_class_column(tokens)
314+
315+
# For all multiline assign statements, get the line numbers after the first line of the assignment
316+
# This is used to avoid identifying comments in multiline assign statements
317+
intermediate_assign_lines, final_assign_lines = get_subsequent_assign_lines(source_cls)
250318

251319
# Extract class variables
252320
class_variable = None
253321
variable_to_comment = {}
254-
for tokens in line_to_tokens.values():
255-
for i, token in enumerate(tokens):
322+
for line, tokens in line_to_tokens.items():
323+
# If this is the final line of a multiline assign, extract any potential comments
324+
if line in final_assign_lines:
325+
# Find the comment (if it exists)
326+
for token in tokens:
327+
if token["token_type"] == tokenize.COMMENT:
328+
# Leave out "#" and whitespace from comment
329+
variable_to_comment[class_variable]["comment"] = token["token"][1:].strip()
330+
break
331+
continue
332+
333+
# Skip assign lines after the first line of multiline assign statements
334+
if line in intermediate_assign_lines:
335+
continue
256336

337+
for i, token in enumerate(tokens):
257338
# Skip whitespace
258339
if token["token"].strip() == "":
259340
continue
@@ -265,8 +346,21 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
265346
and token["token"][:1] in {'"', "'"}
266347
):
267348
sep = " " if variable_to_comment[class_variable]["comment"] else ""
349+
350+
# Identify the quote character (single or double)
268351
quote_char = token["token"][:1]
269-
variable_to_comment[class_variable]["comment"] += sep + token["token"].strip(quote_char).strip()
352+
353+
# Identify the number of quote characters at the start of the string
354+
num_quote_chars = len(token["token"]) - len(token["token"].lstrip(quote_char))
355+
356+
# Remove the number of quote characters at the start of the string and the end of the string
357+
token["token"] = token["token"][num_quote_chars:-num_quote_chars]
358+
359+
# Decode the escape sequences (e.g. \" becomes "); NOTE(review): the ASCII encode below raises UnicodeEncodeError on non-ASCII text
360+
token["token"] = bytes(token["token"], encoding="ascii").decode("unicode-escape")
361+
362+
# Add the token to the comment, stripping whitespace
363+
variable_to_comment[class_variable]["comment"] += sep + token["token"].strip()
270364

271365
# Match class variable
272366
class_variable = None
@@ -292,7 +386,7 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
292386
return variable_to_comment
293387

294388

295-
def get_literals(literal: Literal, variable: str) -> Tuple[Callable[[str], Any], List[str]]:
389+
def get_literals(literal: Literal, variable: str) -> Tuple[Callable[[str], Any], List[type]]:
296390
"""Extracts the values from a Literal type and ensures that the values are all primitive types."""
297391
literals = list(get_args(literal))
298392

@@ -449,33 +543,6 @@ def as_python_object(dct: Any) -> Any:
449543
return dct
450544

451545

452-
def fix_py36_copy(func: Callable) -> Callable:
453-
"""Decorator that fixes functions using Python 3.6 deepcopy of ArgumentParsers.
454-
455-
Based on https://stackoverflow.com/questions/6279305/typeerror-cannot-deepcopy-this-pattern-object
456-
"""
457-
if sys.version_info[:2] > (3, 6):
458-
return func
459-
460-
@wraps(func)
461-
def wrapper(*args, **kwargs):
462-
re_type = type(re.compile(""))
463-
has_prev_val = re_type in copy._deepcopy_dispatch
464-
prev_val = copy._deepcopy_dispatch.get(re_type, None)
465-
copy._deepcopy_dispatch[type(re.compile(""))] = lambda r, _: r
466-
467-
result = func(*args, **kwargs)
468-
469-
if has_prev_val:
470-
copy._deepcopy_dispatch[re_type] = prev_val
471-
else:
472-
del copy._deepcopy_dispatch[re_type]
473-
474-
return result
475-
476-
return wrapper
477-
478-
479546
def enforce_reproducibility(
480547
saved_reproducibility_data: Optional[Dict[str, str]], current_reproducibility_data: Dict[str, str], path: PathLike
481548
) -> None:
@@ -512,7 +579,7 @@ def enforce_reproducibility(
512579
raise ValueError(f"{no_reproducibility_message}: Uncommitted changes " f"in current args.")
513580

514581

515-
# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.8, 3.9, and 3.10
582+
# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.9 and 3.10
516583
# https://github.com/ilevkivskyi/typing_inspect/issues/64
517584
# https://github.com/ilevkivskyi/typing_inspect/issues/65
518585
def get_origin(tp: Any) -> Any:

tests/test_actions.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ def configure(self):
171171
# tried redirecting stderr using unittest.mock.patch
172172
# VersionTap().parse_args(['--version'])
173173

174-
@unittest.skipIf(sys.version_info < (3, 8), 'action="extend" introduced in argparse in Python 3.8')
175174
def test_actions_extend(self):
176175
class ExtendTap(Tap):
177176
arg = [1, 2]
@@ -185,7 +184,6 @@ def configure(self):
185184
args = ExtendTap().parse_args("--arg a b --arg a --arg c d".split())
186185
self.assertEqual(args.arg, [1, 2] + "a b a c d".split())
187186

188-
@unittest.skipIf(sys.version_info < (3, 8), 'action="extend" introduced in argparse in Python 3.8')
189187
def test_actions_extend_list(self):
190188
class ExtendListTap(Tap):
191189
arg: List = ["hi"]
@@ -196,7 +194,6 @@ def configure(self):
196194
args = ExtendListTap().parse_args("--arg yo yo --arg yoyo --arg yo yo".split())
197195
self.assertEqual(args.arg, "hi yo yo yoyo yo yo".split())
198196

199-
@unittest.skipIf(sys.version_info < (3, 8), 'action="extend" introduced in argparse in Python 3.8')
200197
def test_actions_extend_list_int(self):
201198
class ExtendListIntTap(Tap):
202199
arg: List[int] = [0]

0 commit comments

Comments
 (0)