From 76088bc9fbfd7b0aa2ab4fefd9b3bff1de1b3cd7 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 23 Mar 2025 21:45:09 +0000 Subject: [PATCH 1/2] fix logic for finding print statement --- pyproject.toml | 2 +- pytest_examples/run_code.py | 37 ++++++++++++++++---------------- tests/test_insert_print.py | 42 +++++++++++++++++++++++++++++-------- uv.lock | 2 +- 4 files changed, 53 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2719fa8..9ffdee5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ include = ["/README.md", "/Makefile", "/pytest_examples", "/tests"] [project] name = "pytest-examples" -version = "0.0.16" +version = "0.0.17" description = "Pytest plugin for testing examples in docstrings and markdown files." authors = [ {name = "Samuel Colvin", email = "s@muelcolvin.com"}, diff --git a/pytest_examples/run_code.py b/pytest_examples/run_code.py index 903f414..cf19434 100644 --- a/pytest_examples/run_code.py +++ b/pytest_examples/run_code.py @@ -156,8 +156,8 @@ def __call__(self, *args: Any, sep: str = ' ', **kwargs: Any) -> None: frame = inspect.stack()[parent_frame_id] if self._include_file(frame, args): - # -1 to account for the line number being 1-indexed - s = PrintStatement(frame.lineno, sep, [Arg(arg) for arg in args]) + lineno = self._find_line_number(frame) + s = PrintStatement(lineno, sep, [Arg(arg) for arg in args]) self.statements.append(s) def _include_file(self, frame: inspect.FrameInfo, args: Sequence[Any]) -> bool: @@ -166,6 +166,17 @@ def _include_file(self, frame: inspect.FrameInfo, args: Sequence[Any]) -> bool: else: return self.file.samefile(frame.filename) + def _find_line_number(self, inspect_frame: inspect.FrameInfo) -> int: + """Find the line number of the print statement in the file that is being executed.""" + frame = inspect_frame.frame + while True: + if self.file.samefile(frame.f_code.co_filename): + return frame.f_lineno + elif frame.f_back: + frame = frame.f_back + else: + raise RuntimeError(f'Could not find line number of print statement at {inspect_frame}') + class InsertPrintStatements: def __init__( @@ -256,18 +267,6 @@ def _insert_print_args( triple_quotes_prefix_re = re.compile('^ *(?:"{3}|\'{3})', re.MULTILINE) -def find_print_line(lines: list[str], line_no: int) -> int: - """For 3.7 we have to reverse through lines to find the print statement lint.""" - return line_no - - for back in range(100): - new_line_no = line_no - back - m = re.search(r'^ *print\(', lines[new_line_no - 1]) - if m: - return new_line_no - return line_no - - def remove_old_print(lines: list[str], line_index: int) -> None: """Remove the old print statement.""" try: @@ -294,12 +293,12 @@ def remove_old_print(lines: list[str], line_index: int) -> None: def find_print_location(example: CodeExample, line_no: int) -> tuple[int, int]: """Find the line and column of the print statement. - :param example: the `CodeExample` - :param line_no: The line number on which the print statement starts (or approx on 3.7) - :return: tuple if `(line, column)` of the print statement - """ - # For 3.7 we have to reverse through lines to find the print statement lint + Args: + example: the `CodeExample` + line_no: The line number on which the print statement starts or approx + Return: tuple if `(line, column)` of the print statement + """ m = ast.parse(example.source, filename=example.path.name) return find_print(m, line_no) or (line_no, 0) diff --git a/tests/test_insert_print.py b/tests/test_insert_print.py index 911c521..48b6892 100644 --- a/tests/test_insert_print.py +++ b/tests/test_insert_print.py @@ -1,5 +1,7 @@ from __future__ import annotations as _annotations +import sys + import pytest from _pytest.outcomes import Failed @@ -397,8 +399,6 @@ def main(): def test_run_main_print(tmp_path, eval_example): - # note this file is no written here as it's not required - md_file = tmp_path / 'test.md' python_code = """ main_called = False @@ -408,7 +408,8 @@ def main(): print(1, 2, 3) #> 1 2 3 """ - example = CodeExample.create(python_code, path=md_file) + # note this file is no written here as it's not required + example = CodeExample.create(python_code, path=tmp_path / 'test.md') eval_example.set_config(line_length=30) module_dict = eval_example.run_print_check(example, call='main') @@ -416,8 +417,6 @@ def main(): def test_run_main_print_async(tmp_path, eval_example): - # note this file is no written here as it's not required - md_file = tmp_path / 'test.md' python_code = """ main_called = False @@ -427,7 +426,8 @@ async def main(): print(1, 2, 3) #> 1 2 3 """ - example = CodeExample.create(python_code, path=md_file) + # note this file is no written here as it's not required + example = CodeExample.create(python_code, path=tmp_path / 'test.md') eval_example.set_config(line_length=30) module_dict = eval_example.run_print_check(example, call='main') @@ -435,14 +435,13 @@ async def main(): def test_custom_include_print(tmp_path, eval_example): - # note this file is no written here as it's not required - md_file = tmp_path / 'test.md' python_code = """ print('yes') #> yes print('no') """ - example = CodeExample.create(python_code, path=md_file) + # note this file is no written here as it's not required + example = CodeExample.create(python_code, path=tmp_path / 'test.md') eval_example.set_config(line_length=30) def custom_include_print(path, frame, args): @@ -451,3 +450,28 @@ def custom_include_print(path, frame, args): eval_example.include_print = custom_include_print eval_example.run_print_check(example, call='main') + + +def test_print_different_file(tmp_path, eval_example): + other_file = tmp_path / 'other.py' + other_code = """ +def does_print(): + print('hello') + """ + other_file.write_text(other_code) + sys.path.append(str(tmp_path)) + python_code = """ +import other + +other.does_print() +#> hello +""" + example = CodeExample.create(python_code, path=tmp_path / 'test.md') + + eval_example.include_print = lambda p, f, a: True + + eval_example.run_print_check(example, call='main') + + other_file.write_text(('\n' * 30) + other_code) + + eval_example.run_print_check(example, call='main') diff --git a/uv.lock b/uv.lock index 178d1f5..7980926 100644 --- a/uv.lock +++ b/uv.lock @@ -331,7 +331,7 @@ wheels = [ [[package]] name = "pytest-examples" -version = "0.0.16" +version = "0.0.17" source = { editable = "." } dependencies = [ { name = "black" }, From 006a9f9542d1ff39c47d71c1640d5b0ad386886e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 23 Mar 2025 21:50:13 +0000 Subject: [PATCH 2/2] fix test --- tests/test_insert_print.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_insert_print.py b/tests/test_insert_print.py index 48b6892..37f8a9f 100644 --- a/tests/test_insert_print.py +++ b/tests/test_insert_print.py @@ -472,6 +472,7 @@ def does_print(): eval_example.run_print_check(example, call='main') + del sys.modules['other'] other_file.write_text(('\n' * 30) + other_code) eval_example.run_print_check(example, call='main')