Skip to content

Commit b99c016

Browse files
committed
Split checker and codemod tests into individual test cases via metafunc
Few pathlib cleanups
1 parent 63cf152 commit b99c016

File tree

1 file changed

+43
-35
lines changed

1 file changed

+43
-35
lines changed

tests/test_torchfix.py

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,68 +16,76 @@
1616
LOGGER = logging.getLogger(__name__)
1717

1818

19+
def pytest_generate_tests(metafunc):
20+
# Dynamically generate test cases from paths
21+
if "checker_source_path" in metafunc.fixturenames:
22+
files = list(FIXTURES_PATH.glob("**/checker/*.py"))
23+
metafunc.parametrize(
24+
"checker_source_path", files, ids=[file_name.stem for file_name in files]
25+
)
26+
if "codemod_source_path" in metafunc.fixturenames:
27+
files = list(FIXTURES_PATH.glob("**/codemod/*.py"))
28+
metafunc.parametrize(
29+
"codemod_source_path", files, ids=[file_name.stem for file_name in files]
30+
)
31+
if "case" in metafunc.fixturenames:
32+
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
33+
cases = [
34+
("ALL", GET_ALL_ERROR_CODES()),
35+
("ALL,TOR102", GET_ALL_ERROR_CODES()),
36+
("TOR102", {"TOR102"}),
37+
("TOR102,TOR101", {"TOR102", "TOR101"}),
38+
("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}),
39+
(None, GET_ALL_ERROR_CODES() - exclude_set),
40+
]
41+
metafunc.parametrize("case,expected", cases, ids=[case for case, _ in cases])
42+
43+
1944
def _checker_results(s):
2045
checker = TorchChecker(None, s)
2146
return [f"{line}:{col} {msg}" for line, col, msg, _ in checker.run()]
2247

2348

24-
def _codemod_results(source_path):
25-
with open(source_path) as source:
26-
code = source.read()
49+
def _codemod_results(source_path: Path):
50+
code = source_path.read_text()
2751
config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
28-
context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
52+
context = TorchCodemod(codemod.CodemodContext(filename=str(source_path)), config)
2953
new_module = codemod.transform_module(context, code)
3054
if isinstance(new_module, codemod.TransformSuccess):
3155
return new_module.code
32-
elif isinstance(new_module, codemod.TransformFailure):
56+
if isinstance(new_module, codemod.TransformFailure):
3357
raise new_module.error
3458

3559

3660
def test_empty():
3761
assert _checker_results([""]) == []
3862

3963

40-
def test_checker_fixtures():
41-
for source_path in FIXTURES_PATH.glob("**/checker/*.py"):
42-
LOGGER.info("Testing %s", source_path.relative_to(Path.cwd()))
43-
expected_path = str(source_path)[:-2] + "txt"
44-
expected_results = []
45-
with open(expected_path) as expected:
46-
for line in expected:
47-
expected_results.append(line.rstrip())
64+
def test_checker_fixtures(checker_source_path: Path):
65+
expected_path = checker_source_path.with_suffix(".txt")
66+
expected_results = expected_path.read_text().splitlines()
4867

49-
with open(source_path) as source:
50-
assert _checker_results(source.readlines()) == expected_results
68+
assert (
69+
_checker_results(checker_source_path.read_text().splitlines(keepends=True))
70+
== expected_results
71+
)
5172

5273

53-
def test_codemod_fixtures():
54-
for source_path in FIXTURES_PATH.glob("**/codemod/*.py"):
55-
LOGGER.info("Testing %s", source_path.relative_to(Path.cwd()))
56-
expected_path = source_path.with_suffix(".py.out")
57-
expected_results = expected_path.read_text()
58-
assert _codemod_results(source_path) == expected_results
74+
def test_codemod_fixtures(codemod_source_path: Path):
75+
expected_path = codemod_source_path.with_suffix(".py.out")
76+
expected_results = expected_path.read_text()
77+
assert _codemod_results(codemod_source_path) == expected_results
5978

6079

6180
def test_errorcodes_distinct():
6281
visitors = GET_ALL_VISITORS()
6382
seen = set()
6483
for visitor in visitors:
6584
LOGGER.info("Checking error code for %s", visitor.__class__.__name__)
66-
errors = visitor.ERRORS
67-
for e in errors:
85+
for e in visitor.ERRORS:
6886
assert e not in seen
6987
seen.add(e)
7088

7189

72-
def test_parse_error_code_str():
73-
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
74-
cases = [
75-
("ALL", GET_ALL_ERROR_CODES()),
76-
("ALL,TOR102", GET_ALL_ERROR_CODES()),
77-
("TOR102", {"TOR102"}),
78-
("TOR102,TOR101", {"TOR102", "TOR101"}),
79-
("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}),
80-
(None, GET_ALL_ERROR_CODES() - exclude_set),
81-
]
82-
for case, expected in cases:
83-
assert expected == process_error_code_str(case)
90+
def test_parse_error_code_str(case, expected):
91+
assert process_error_code_str(case) == expected

0 commit comments

Comments
 (0)