Skip to content

Commit 7ea9791

Browse files
authored
Merge pull request #149 from arnaud-ma/better-performance
Improve performance by: (1) When extracting arguments from descendants (via inheritance) don't extract argument from the top-level Tap class. (2) When extracting docstrings, only read source code once.
2 parents c0d4b75 + 09cb610 commit 7ea9791

File tree

3 files changed

+49
-32
lines changed

3 files changed

+49
-32
lines changed

src/tap/tap.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> Union
503503
while len(super_classes) > 0:
504504
super_class = super_classes.pop(0)
505505

506-
if super_class not in visited and issubclass(super_class, Tap):
506+
if super_class not in visited and issubclass(super_class, Tap) and super_class is not Tap:
507507
super_dictionary = extract_func(super_class)
508508

509509
# Update only unseen variables to avoid overriding subclass values
@@ -529,9 +529,7 @@ def _get_class_dict(self) -> Dict[str, Any]:
529529
if not (
530530
var.startswith("_")
531531
or callable(val)
532-
or isinstance(val, staticmethod)
533-
or isinstance(val, classmethod)
534-
or isinstance(val, property)
532+
or isinstance(val, (staticmethod, classmethod, property))
535533
)
536534
}
537535

@@ -546,9 +544,7 @@ def _get_class_variables(self) -> dict:
546544
class_variable_names = {**self._get_annotations(), **self._get_class_dict()}.keys()
547545

548546
try:
549-
class_variables = self._get_from_self_and_super(
550-
extract_func=lambda super_class: get_class_variables(super_class)
551-
)
547+
class_variables = self._get_from_self_and_super(extract_func=get_class_variables)
552548

553549
# Handle edge-case of source code modification while code is running
554550
variables_to_add = class_variable_names - class_variables.keys()

src/tap/utils.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Callable,
1919
Dict,
2020
Generator,
21+
Iterable,
2122
Iterator,
2223
List,
2324
Literal,
@@ -184,29 +185,31 @@ def is_positional_arg(*name_or_flags) -> bool:
184185
return not is_option_arg(*name_or_flags)
185186

186187

187-
def tokenize_source(obj: object) -> Generator:
188-
"""Returns a generator for the tokens of the object's source code."""
189-
source = inspect.getsource(obj)
190-
token_generator = tokenize.generate_tokens(StringIO(source).readline)
191-
return token_generator
188+
def tokenize_source(source: str) -> Generator[tokenize.TokenInfo, None, None]:
189+
"""Returns a generator for the tokens of the object's source code, given the source code."""
190+
return tokenize.generate_tokens(StringIO(source).readline)
192191

193192

194-
def get_class_column(obj: type) -> int:
195-
"""Determines the column number for class variables in a class."""
193+
def get_class_column(tokens: Iterable[tokenize.TokenInfo]) -> int:
194+
"""Determines the column number for class variables in a class, given the tokens of the class."""
196195
first_line = 1
197-
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
196+
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
198197
if token.strip() == "@":
199198
first_line += 1
200199
if start_line <= first_line or token.strip() == "":
201200
continue
202201

203202
return start_column
203+
raise ValueError("Could not find any class variables in the class.")
204204

205205

206-
def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, int]]]]:
207-
"""Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code."""
206+
def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> Dict[int, List[Dict[str, Union[str, int]]]]:
207+
"""
208+
Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code,
209+
given the tokens of the object's source code.
210+
"""
208211
line_to_tokens = {}
209-
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
212+
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
210213
line_to_tokens.setdefault(start_line, []).append({
211214
'token_type': token_type,
212215
'token': token,
@@ -220,13 +223,14 @@ def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, in
220223
return line_to_tokens
221224

222225

223-
def get_subsequent_assign_lines(cls: type) -> Set[int]:
224-
"""For all multiline assign statements, get the line numbers after the first line of the assignment."""
225-
# Get source code of class
226-
source = inspect.getsource(cls)
226+
def get_subsequent_assign_lines(source_cls: str) -> Set[int]:
227+
"""
228+
For all multiline assign statements, get the line numbers after the first line of the assignment,
229+
given the source code of the object.
230+
"""
227231

228232
# Parse source code using ast (with an if statement to avoid indentation errors)
229-
source = f"if True:\n{textwrap.indent(source, ' ')}"
233+
source = f"if True:\n{textwrap.indent(source_cls, ' ')}"
230234
body = ast.parse(source).body[0]
231235

232236
# Set up warning message
@@ -260,6 +264,11 @@ def get_subsequent_assign_lines(cls: type) -> Set[int]:
260264
assign_lines = set()
261265
for node in cls_body.body:
262266
if isinstance(node, (ast.Assign, ast.AnnAssign)):
267+
# Check if the end line number is found
268+
if node.end_lineno is None:
269+
warnings.warn(parse_warning)
270+
continue
271+
263272
# Get line number of assign statement excluding the first line (and minus 1 for the if statement)
264273
assign_lines |= set(range(node.lineno, node.end_lineno))
265274

@@ -268,15 +277,19 @@ def get_subsequent_assign_lines(cls: type) -> Set[int]:
268277

269278
def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
270279
"""Returns a dictionary mapping class variables to their additional information (currently just comments)."""
280+
# Get the source code and tokens of the class
281+
source_cls = inspect.getsource(cls)
282+
tokens = tuple(tokenize_source(source_cls))
283+
271284
# Get mapping from line number to tokens
272-
line_to_tokens = source_line_to_tokens(cls)
285+
line_to_tokens = source_line_to_tokens(tokens)
273286

274287
# Get class variable column number
275-
class_variable_column = get_class_column(cls)
288+
class_variable_column = get_class_column(tokens)
276289

277290
# For all multiline assign statements, get the line numbers after the first line of the assignment
278291
# This is used to avoid identifying comments in multiline assign statements
279-
subsequent_assign_lines = get_subsequent_assign_lines(cls)
292+
subsequent_assign_lines = get_subsequent_assign_lines(source_cls)
280293

281294
# Extract class variables
282295
class_variable = None

tests/test_utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from argparse import ArgumentTypeError
2+
import inspect
23
import json
34
import os
45
import subprocess
@@ -11,6 +12,7 @@
1112
get_class_column,
1213
get_class_variables,
1314
GitInfo,
15+
tokenize_source,
1416
type_to_str,
1517
get_literals,
1618
TupleTypeEnforcer,
@@ -145,7 +147,8 @@ def test_column_simple(self):
145147
class SimpleColumn:
146148
arg = 2
147149

148-
self.assertEqual(get_class_column(SimpleColumn), 12)
150+
tokens = tokenize_source(inspect.getsource(SimpleColumn))
151+
self.assertEqual(get_class_column(tokens), 12)
149152

150153
def test_column_comment(self):
151154
class CommentColumn:
@@ -158,28 +161,32 @@ class CommentColumn:
158161

159162
arg = 2
160163

161-
self.assertEqual(get_class_column(CommentColumn), 12)
164+
tokens = tokenize_source(inspect.getsource(CommentColumn))
165+
self.assertEqual(get_class_column(tokens), 12)
162166

163167
def test_column_space(self):
164168
class SpaceColumn:
165169

166170
arg = 2
167171

168-
self.assertEqual(get_class_column(SpaceColumn), 12)
172+
tokens = tokenize_source(inspect.getsource(SpaceColumn))
173+
self.assertEqual(get_class_column(tokens), 12)
169174

170175
def test_column_method(self):
171176
class FuncColumn:
172177
def func(self):
173178
pass
174179

175-
self.assertEqual(get_class_column(FuncColumn), 12)
180+
tokens = tokenize_source(inspect.getsource(FuncColumn))
181+
self.assertEqual(get_class_column(tokens), 12)
176182

177183
def test_dataclass(self):
178184
@class_decorator
179185
class DataclassColumn:
180186
arg: int = 5
181187

182-
self.assertEqual(get_class_column(DataclassColumn), 12)
188+
tokens = tokenize_source(inspect.getsource(DataclassColumn))
189+
self.assertEqual(get_class_column(tokens), 12)
183190

184191
def test_dataclass_method(self):
185192
def wrapper(f):
@@ -191,7 +198,8 @@ class DataclassColumn:
191198
def func(self):
192199
pass
193200

194-
self.assertEqual(get_class_column(DataclassColumn), 12)
201+
tokens = tokenize_source(inspect.getsource(DataclassColumn))
202+
self.assertEqual(get_class_column(tokens), 12)
195203

196204

197205
class ClassVariableTests(TestCase):

0 commit comments

Comments
 (0)