Skip to content

Commit 69f78d3

Browse files
committed
Fixing comment extraction in the case of multiline assign statements using ast
1 parent 6405e30 commit 69f78d3

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

src/tap/utils.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from argparse import ArgumentParser, ArgumentTypeError
2+
import ast
23
from base64 import b64encode, b64decode
34
import copy
45
from functools import wraps
@@ -10,6 +11,7 @@
1011
import re
1112
import subprocess
1213
import sys
14+
import textwrap
1315
import tokenize
1416
from typing import (
1517
Any,
@@ -24,6 +26,7 @@
2426
Union,
2527
)
2628
from typing_inspect import get_args as typing_inspect_get_args, get_origin as typing_inspect_get_origin
29+
import warnings
2730

2831
if sys.version_info >= (3, 10):
2932
from types import UnionType
@@ -216,6 +219,52 @@ def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, in
216219
return line_to_tokens
217220

218221

222+
def get_subsequent_assign_lines(cls: type) -> set[int]:
223+
"""For all multiline assign statements, get the line numbers after the first line of the assignment."""
224+
# Get source code of class
225+
source = inspect.getsource(cls)
226+
227+
# Parse source code using ast (with an if statement to avoid indentation errors)
228+
source = f"if True:\n{textwrap.indent(source, ' ')}"
229+
body = ast.parse(source).body[0]
230+
231+
# Set up warning message
232+
parse_warning = (
233+
"Could not parse class source code to extract comments. "
234+
"Comments in the help string may be incorrect."
235+
)
236+
237+
# Check for correct parsing
238+
if not isinstance(body, ast.If):
239+
warnings.warn(parse_warning)
240+
return set()
241+
242+
# Extract if body
243+
if_body = body.body
244+
245+
# Check for a single body
246+
if len(if_body) != 1:
247+
warnings.warn(parse_warning)
248+
return set()
249+
250+
# Extract class body
251+
cls_body = if_body[0]
252+
253+
# Check for a single class definition
254+
if not isinstance(cls_body, ast.ClassDef):
255+
warnings.warn(parse_warning)
256+
return set()
257+
258+
# Get line numbers of assign statements
259+
assign_lines = set()
260+
for node in cls_body.body:
261+
if isinstance(node, (ast.Assign, ast.AnnAssign)):
262+
# Get line number of assign statement excluding the first line (and minus 1 for the if statement)
263+
assign_lines |= set(range(node.lineno, node.end_lineno))
264+
265+
return assign_lines
266+
267+
219268
def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
220269
"""Returns a dictionary mapping class variables to their additional information (currently just comments)."""
221270
# Get mapping from line number to tokens
@@ -224,12 +273,19 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
224273
# Get class variable column number
225274
class_variable_column = get_class_column(cls)
226275

276+
# For all multiline assign statements, get the line numbers after the first line of the assignment
277+
# This is used to avoid identifying comments in multiline assign statements
278+
subsequent_assign_lines = get_subsequent_assign_lines(cls)
279+
227280
# Extract class variables
228281
class_variable = None
229282
variable_to_comment = {}
230-
for tokens in line_to_tokens.values():
231-
for i, token in enumerate(tokens):
283+
for line, tokens in line_to_tokens.items():
284+
# Skip assign lines after the first line of multiline assign statements
285+
if line in subsequent_assign_lines:
286+
continue
232287

288+
for i, token in enumerate(tokens):
233289
# Skip whitespace
234290
if token["token"].strip() == "":
235291
continue

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ class MultilineArgument:
319319
"This is a multiline argument"
320320
" that should not be included in the docstring"
321321
)
322-
("""biz baz""")
322+
"""biz baz"""
323323

324324
class_variables = {"bar": {"comment": "biz baz"}}
325325
self.assertEqual(get_class_variables(MultilineArgument), class_variables)

0 commit comments

Comments
 (0)