|
16 | 16 | LOGGER = logging.getLogger(__name__)
|
17 | 17 |
|
18 | 18 |
|
| 19 | +def pytest_generate_tests(metafunc): |
| 20 | + # Dynamically generate test cases from paths |
| 21 | + if "checker_source_path" in metafunc.fixturenames: |
| 22 | + files = list(FIXTURES_PATH.glob("**/checker/*.py")) |
| 23 | + metafunc.parametrize( |
| 24 | + "checker_source_path", files, ids=[file_name.stem for file_name in files] |
| 25 | + ) |
| 26 | + if "codemod_source_path" in metafunc.fixturenames: |
| 27 | + files = list(FIXTURES_PATH.glob("**/codemod/*.py")) |
| 28 | + metafunc.parametrize( |
| 29 | + "codemod_source_path", files, ids=[file_name.stem for file_name in files] |
| 30 | + ) |
| 31 | + if "case" in metafunc.fixturenames: |
| 32 | + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) |
| 33 | + cases = [ |
| 34 | + ("ALL", GET_ALL_ERROR_CODES()), |
| 35 | + ("ALL,TOR102", GET_ALL_ERROR_CODES()), |
| 36 | + ("TOR102", {"TOR102"}), |
| 37 | + ("TOR102,TOR101", {"TOR102", "TOR101"}), |
| 38 | + ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), |
| 39 | + (None, set(GET_ALL_ERROR_CODES()) - exclude_set), |
| 40 | + ] |
| 41 | + metafunc.parametrize("case,expected", cases, ids=[case for case, _ in cases]) |
| 42 | + |
| 43 | + |
19 | 44 | def _checker_results(s):
|
20 | 45 | checker = TorchChecker(None, s)
|
21 | 46 | return [f"{line}:{col} {msg}" for line, col, msg, _ in checker.run()]
|
22 | 47 |
|
23 | 48 |
|
24 |
| -def _codemod_results(source_path): |
25 |
| - with open(source_path) as source: |
26 |
| - code = source.read() |
| 49 | +def _codemod_results(source_path: Path): |
| 50 | + code = source_path.read_text() |
27 | 51 | config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
|
28 |
| - context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) |
| 52 | + context = TorchCodemod(codemod.CodemodContext(filename=str(source_path)), config) |
29 | 53 | new_module = codemod.transform_module(context, code)
|
30 | 54 | if isinstance(new_module, codemod.TransformSuccess):
|
31 | 55 | return new_module.code
|
32 |
| - elif isinstance(new_module, codemod.TransformFailure): |
| 56 | + if isinstance(new_module, codemod.TransformFailure): |
33 | 57 | raise new_module.error
|
34 | 58 |
|
35 | 59 |
|
36 | 60 | def test_empty():
|
37 | 61 | assert _checker_results([""]) == []
|
38 | 62 |
|
39 | 63 |
|
40 |
| -def test_checker_fixtures(): |
41 |
| - for source_path in FIXTURES_PATH.glob("**/checker/*.py"): |
42 |
| - LOGGER.info("Testing %s", source_path.relative_to(Path.cwd())) |
43 |
| - expected_path = str(source_path)[:-2] + "txt" |
44 |
| - expected_results = [] |
45 |
| - with open(expected_path) as expected: |
46 |
| - for line in expected: |
47 |
| - expected_results.append(line.rstrip()) |
| 64 | +def test_checker_fixtures(checker_source_path: Path): |
| 65 | + expected_path = checker_source_path.with_suffix(".txt") |
| 66 | + expected_results = expected_path.read_text().splitlines() |
48 | 67 |
|
49 |
| - with open(source_path) as source: |
50 |
| - assert _checker_results(source.readlines()) == expected_results |
| 68 | + assert ( |
| 69 | + _checker_results(checker_source_path.read_text().splitlines(keepends=True)) |
| 70 | + == expected_results |
| 71 | + ) |
51 | 72 |
|
52 | 73 |
|
53 |
| -def test_codemod_fixtures(): |
54 |
| - for source_path in FIXTURES_PATH.glob("**/codemod/*.py"): |
55 |
| - LOGGER.info("Testing %s", source_path.relative_to(Path.cwd())) |
56 |
| - expected_path = source_path.with_suffix(".py.out") |
57 |
| - expected_results = expected_path.read_text() |
58 |
| - assert _codemod_results(source_path) == expected_results |
| 74 | +def test_codemod_fixtures(codemod_source_path: Path): |
| 75 | + expected_path = codemod_source_path.with_suffix(".py.out") |
| 76 | + expected_results = expected_path.read_text() |
| 77 | + assert _codemod_results(codemod_source_path) == expected_results |
59 | 78 |
|
60 | 79 |
|
61 | 80 | def test_errorcodes_distinct():
|
62 | 81 | visitors = GET_ALL_VISITORS()
|
63 | 82 | seen = set()
|
64 | 83 | for visitor in visitors:
|
65 | 84 | LOGGER.info("Checking error code for %s", visitor.__class__.__name__)
|
66 |
| - errors = visitor.ERRORS |
67 |
| - for e in errors: |
| 85 | + for e in visitor.ERRORS: |
68 | 86 | assert e.error_code not in seen
|
69 | 87 | seen.add(e.error_code)
|
70 | 88 |
|
71 | 89 |
|
72 |
| -def test_parse_error_code_str(): |
73 |
| - exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) |
74 |
| - cases = [ |
75 |
| - ("ALL", GET_ALL_ERROR_CODES()), |
76 |
| - ("ALL,TOR102", GET_ALL_ERROR_CODES()), |
77 |
| - ("TOR102", {"TOR102"}), |
78 |
| - ("TOR102,TOR101", {"TOR102", "TOR101"}), |
79 |
| - ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), |
80 |
| - (None, set(GET_ALL_ERROR_CODES()) - exclude_set), |
81 |
| - ] |
82 |
| - for case, expected in cases: |
83 |
| - assert expected == process_error_code_str(case) |
| 90 | +def test_parse_error_code_str(case, expected): |
| 91 | + assert process_error_code_str(case) == expected |
0 commit comments