Skip to content

Commit c0d4b75

Browse files
authored
Merge pull request #148 from swansonk14/quote-docstrings
Fixing issues with comment extraction for the help string
2 parents 5257fe8 + 3980c81 commit c0d4b75

File tree

2 files changed

+106
-15
lines changed

2 files changed

+106
-15
lines changed

src/tap/utils.py

Lines changed: 82 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from argparse import ArgumentParser, ArgumentTypeError
2+
import ast
23
from base64 import b64encode, b64decode
34
import copy
45
from functools import wraps
@@ -10,6 +11,7 @@
1011
import re
1112
import subprocess
1213
import sys
14+
import textwrap
1315
import tokenize
1416
from typing import (
1517
Any,
@@ -20,10 +22,12 @@
2022
List,
2123
Literal,
2224
Optional,
25+
Set,
2326
Tuple,
2427
Union,
2528
)
2629
from typing_inspect import get_args as typing_inspect_get_args, get_origin as typing_inspect_get_origin
30+
import warnings
2731

2832
if sys.version_info >= (3, 10):
2933
from types import UnionType
@@ -184,7 +188,6 @@ def tokenize_source(obj: object) -> Generator:
184188
"""Returns a generator for the tokens of the object's source code."""
185189
source = inspect.getsource(obj)
186190
token_generator = tokenize.generate_tokens(StringIO(source).readline)
187-
188191
return token_generator
189192

190193

@@ -204,21 +207,65 @@ def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, in
204207
"""Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code."""
205208
line_to_tokens = {}
206209
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
207-
line_to_tokens.setdefault(start_line, []).append(
208-
{
209-
"token_type": token_type,
210-
"token": token,
211-
"start_line": start_line,
212-
"start_column": start_column,
213-
"end_line": end_line,
214-
"end_column": end_column,
215-
"line": line,
216-
}
217-
)
210+
line_to_tokens.setdefault(start_line, []).append({
211+
'token_type': token_type,
212+
'token': token,
213+
'start_line': start_line,
214+
'start_column': start_column,
215+
'end_line': end_line,
216+
'end_column': end_column,
217+
'line': line
218+
})
218219

219220
return line_to_tokens
220221

221222

223+
def get_subsequent_assign_lines(cls: type) -> Set[int]:
    """Return the continuation lines of every multiline class-level assignment.

    For each assignment (plain or annotated) found directly in the class body,
    every line of the statement except its first line is collected. Line
    numbers are 1-indexed relative to the text returned by
    ``inspect.getsource(cls)``.

    :param cls: The class whose source code is inspected.
    :return: The set of line numbers that belong to an assignment statement
             but are not its first line. An empty set is returned (and a
             warning is issued) if the class source cannot be parsed into a
             single class definition.
    """
    # Get source code of class
    source = inspect.getsource(cls)

    # Set up warning message
    parse_warning = (
        "Could not parse class source code to extract comments. "
        "Comments in the help string may be incorrect."
    )

    # Wrap the source in an if statement so that a class defined in an indented
    # scope still parses; every line is indented uniformly beneath the wrapper
    wrapped_source = f"if True:\n{textwrap.indent(source, ' ')}"

    # Source that cannot be parsed is handled like any other unexpected parse
    # result: warn and report no continuation lines instead of raising
    try:
        body = ast.parse(wrapped_source).body[0]
    except (SyntaxError, ValueError):
        warnings.warn(parse_warning)
        return set()

    # Check for correct parsing
    if not isinstance(body, ast.If):
        warnings.warn(parse_warning)
        return set()

    # Extract if body
    if_body = body.body

    # Check for a single body
    if len(if_body) != 1:
        warnings.warn(parse_warning)
        return set()

    # Extract class body
    cls_body = if_body[0]

    # Check for a single class definition
    if not isinstance(cls_body, ast.ClassDef):
        warnings.warn(parse_warning)
        return set()

    # Get line numbers of assign statements
    assign_lines: Set[int] = set()
    for node in cls_body.body:
        if isinstance(node, (ast.Assign, ast.AnnAssign)):
            # The "if True:" wrapper shifts every parsed line number up by one,
            # so [lineno, end_lineno) in the wrapped source corresponds to the
            # second through last lines of the statement in the class source
            assign_lines.update(range(node.lineno, node.end_lineno))

    return assign_lines
267+
268+
222269
def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
223270
"""Returns a dictionary mapping class variables to their additional information (currently just comments)."""
224271
# Get mapping from line number to tokens
@@ -227,12 +274,19 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
227274
# Get class variable column number
228275
class_variable_column = get_class_column(cls)
229276

277+
# For all multiline assign statements, get the line numbers after the first line of the assignment
278+
# This is used to avoid identifying comments in multiline assign statements
279+
subsequent_assign_lines = get_subsequent_assign_lines(cls)
280+
230281
# Extract class variables
231282
class_variable = None
232283
variable_to_comment = {}
233-
for tokens in line_to_tokens.values():
234-
for i, token in enumerate(tokens):
284+
for line, tokens in line_to_tokens.items():
285+
# Skip assign lines after the first line of multiline assign statements
286+
if line in subsequent_assign_lines:
287+
continue
235288

289+
for i, token in enumerate(tokens):
236290
# Skip whitespace
237291
if token["token"].strip() == "":
238292
continue
@@ -244,8 +298,21 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
244298
and token["token"][:1] in {'"', "'"}
245299
):
246300
sep = " " if variable_to_comment[class_variable]["comment"] else ""
301+
302+
# Identify the quote character (single or double)
247303
quote_char = token["token"][:1]
248-
variable_to_comment[class_variable]["comment"] += sep + token["token"].strip(quote_char).strip()
304+
305+
# Identify the number of quote characters at the start of the string
306+
num_quote_chars = len(token["token"]) - len(token["token"].lstrip(quote_char))
307+
308+
# Remove the number of quote characters at the start of the string and the end of the string
309+
token["token"] = token["token"][num_quote_chars:-num_quote_chars]
310+
311+
# Remove the unicode escape sequences (e.g. "\"")
312+
token["token"] = bytes(token["token"], encoding='ascii').decode('unicode-escape')
313+
314+
# Add the token to the comment, stripping whitespace
315+
variable_to_comment[class_variable]["comment"] += sep + token["token"].strip()
249316

250317
# Match class variable
251318
class_variable = None

tests/test_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,30 @@ class TripleQuoteMultiline:
300300
class_variables = {"bar": {"comment": "biz baz"}, "hi": {"comment": "Hello there"}}
301301
self.assertEqual(get_class_variables(TripleQuoteMultiline), class_variables)
302302

303+
def test_comments_with_quotes(self):
    # Checks that quote characters escaped inside a string comment are kept in
    # the extracted comment, while the quotes delimiting the string are removed.
    # NOTE: the fixture class below is read back from its source text, so no
    # comments are placed inside it.
    class MultiquoteMultiline:
        bar: int = 0
        '\'\'biz baz\''

        hi: str
        "\"Hello there\"\""

    # Expected comments: the string contents with the escapes resolved
    class_variables = {}
    class_variables['bar'] = {'comment': "''biz baz'"}
    class_variables['hi'] = {'comment': '"Hello there""'}
    self.assertEqual(get_class_variables(MultiquoteMultiline), class_variables)
315+
316+
def test_multiline_argument(self):
    # Checks that the string literals making up a multiline default value are
    # not picked up as the variable's comment; only the string that follows the
    # assignment is expected.
    # NOTE: the fixture class below is read back from its source text, so no
    # comments are placed inside it.
    class MultilineArgument:
        bar: str = (
            "This is a multiline argument"
            " that should not be included in the docstring"
        )
        """biz baz"""

    # Expected comment comes solely from the triple-quoted string after the assignment
    class_variables = {"bar": {"comment": "biz baz"}}
    self.assertEqual(get_class_variables(MultilineArgument), class_variables)
326+
303327
def test_single_quote_multiline(self):
304328
class SingleQuoteMultiline:
305329
bar: int = 0

0 commit comments

Comments
 (0)