Skip to content

Sort rule codes in CLI #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_parse_error_code_str():
("TOR102", {"TOR102"}),
("TOR102,TOR101", {"TOR102", "TOR101"}),
("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}),
(None, GET_ALL_ERROR_CODES() - exclude_set),
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),
]
for case, expected in cases:
assert expected == process_error_code_str(case)
18 changes: 6 additions & 12 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def GET_ALL_ERROR_CODES():
for cls in ALL_VISITOR_CLS:
assert issubclass(cls, TorchVisitor)
codes |= {error.error_code for error in cls.ERRORS}
return codes
return sorted(codes)


@functools.cache
Expand All @@ -62,15 +62,12 @@ def expand_error_codes(codes):
def construct_visitor(cls):
if cls is TorchDeprecatedSymbolsVisitor:
return cls(DEPRECATED_CONFIG_PATH)
else:
return cls()

return cls()


def GET_ALL_VISITORS():
out = []
for v in ALL_VISITOR_CLS:
out.append(construct_visitor(v))
return out
return [construct_visitor(v) for v in ALL_VISITOR_CLS]


def get_visitors_with_error_codes(error_codes):
Expand All @@ -87,10 +84,7 @@ def get_visitors_with_error_codes(error_codes):
break
if not found:
raise AssertionError(f"Unknown error code: {error_code}")
out = []
for cls in visitor_classes:
out.append(construct_visitor(cls))
return out
return [construct_visitor(cls) for cls in visitor_classes]


def process_error_code_str(code_str):
Expand All @@ -100,7 +94,7 @@ def process_error_code_str(code_str):
# Default when --select is not provided.
if code_str is None:
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
return GET_ALL_ERROR_CODES() - exclude_set
return set(GET_ALL_ERROR_CODES()) - exclude_set

raw_codes = [s.strip() for s in code_str.split(",")]

Expand Down