Skip to content

Commit 5699fc8

Browse files
authored
Split tests into individual test cases (#67)
`metafunc.parametrize` can be used to create more granular test cases by dynamically generating test fixtures (i.e. arguments). This in our case is helpful for debugging failing test cases, by knowing directly which of the paths are failing. <img width="300" src="https://github.com/user-attachments/assets/538d9c4e-f549-4d5f-a9db-1ef0fd68cb4f"> _Before_ <img width="300" src="https://github.com/user-attachments/assets/aa9cf9ea-6745-4c20-9f3f-97fc62d13020"> _After_ Changes: - Split checker and codemod tests into individual test cases via metafunc - Few pathlib cleanups - Removed obviated logger calls in these functions
1 parent 311cdd7 commit 5699fc8

File tree

1 file changed

+43
-35
lines changed

1 file changed

+43
-35
lines changed

tests/test_torchfix.py

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,68 +16,76 @@
1616
LOGGER = logging.getLogger(__name__)
1717

1818

19+
def pytest_generate_tests(metafunc):
20+
# Dynamically generate test cases from paths
21+
if "checker_source_path" in metafunc.fixturenames:
22+
files = list(FIXTURES_PATH.glob("**/checker/*.py"))
23+
metafunc.parametrize(
24+
"checker_source_path", files, ids=[file_name.stem for file_name in files]
25+
)
26+
if "codemod_source_path" in metafunc.fixturenames:
27+
files = list(FIXTURES_PATH.glob("**/codemod/*.py"))
28+
metafunc.parametrize(
29+
"codemod_source_path", files, ids=[file_name.stem for file_name in files]
30+
)
31+
if "case" in metafunc.fixturenames:
32+
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
33+
cases = [
34+
("ALL", GET_ALL_ERROR_CODES()),
35+
("ALL,TOR102", GET_ALL_ERROR_CODES()),
36+
("TOR102", {"TOR102"}),
37+
("TOR102,TOR101", {"TOR102", "TOR101"}),
38+
("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}),
39+
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),
40+
]
41+
metafunc.parametrize("case,expected", cases, ids=[case for case, _ in cases])
42+
43+
1944
def _checker_results(s):
2045
checker = TorchChecker(None, s)
2146
return [f"{line}:{col} {msg}" for line, col, msg, _ in checker.run()]
2247

2348

24-
def _codemod_results(source_path):
25-
with open(source_path) as source:
26-
code = source.read()
49+
def _codemod_results(source_path: Path):
50+
code = source_path.read_text()
2751
config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
28-
context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
52+
context = TorchCodemod(codemod.CodemodContext(filename=str(source_path)), config)
2953
new_module = codemod.transform_module(context, code)
3054
if isinstance(new_module, codemod.TransformSuccess):
3155
return new_module.code
32-
elif isinstance(new_module, codemod.TransformFailure):
56+
if isinstance(new_module, codemod.TransformFailure):
3357
raise new_module.error
3458

3559

3660
def test_empty():
3761
assert _checker_results([""]) == []
3862

3963

40-
def test_checker_fixtures():
41-
for source_path in FIXTURES_PATH.glob("**/checker/*.py"):
42-
LOGGER.info("Testing %s", source_path.relative_to(Path.cwd()))
43-
expected_path = str(source_path)[:-2] + "txt"
44-
expected_results = []
45-
with open(expected_path) as expected:
46-
for line in expected:
47-
expected_results.append(line.rstrip())
64+
def test_checker_fixtures(checker_source_path: Path):
65+
expected_path = checker_source_path.with_suffix(".txt")
66+
expected_results = expected_path.read_text().splitlines()
4867

49-
with open(source_path) as source:
50-
assert _checker_results(source.readlines()) == expected_results
68+
assert (
69+
_checker_results(checker_source_path.read_text().splitlines(keepends=True))
70+
== expected_results
71+
)
5172

5273

53-
def test_codemod_fixtures():
54-
for source_path in FIXTURES_PATH.glob("**/codemod/*.py"):
55-
LOGGER.info("Testing %s", source_path.relative_to(Path.cwd()))
56-
expected_path = source_path.with_suffix(".py.out")
57-
expected_results = expected_path.read_text()
58-
assert _codemod_results(source_path) == expected_results
74+
def test_codemod_fixtures(codemod_source_path: Path):
75+
expected_path = codemod_source_path.with_suffix(".py.out")
76+
expected_results = expected_path.read_text()
77+
assert _codemod_results(codemod_source_path) == expected_results
5978

6079

6180
def test_errorcodes_distinct():
6281
visitors = GET_ALL_VISITORS()
6382
seen = set()
6483
for visitor in visitors:
6584
LOGGER.info("Checking error code for %s", visitor.__class__.__name__)
66-
errors = visitor.ERRORS
67-
for e in errors:
85+
for e in visitor.ERRORS:
6886
assert e.error_code not in seen
6987
seen.add(e.error_code)
7088

7189

72-
def test_parse_error_code_str():
73-
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
74-
cases = [
75-
("ALL", GET_ALL_ERROR_CODES()),
76-
("ALL,TOR102", GET_ALL_ERROR_CODES()),
77-
("TOR102", {"TOR102"}),
78-
("TOR102,TOR101", {"TOR102", "TOR101"}),
79-
("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}),
80-
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),
81-
]
82-
for case, expected in cases:
83-
assert expected == process_error_code_str(case)
90+
def test_parse_error_code_str(case, expected):
91+
assert process_error_code_str(case) == expected

0 commit comments

Comments
 (0)