Skip to content

Commit 8aef63b

Browse files
authored
Sort rule codes in CLI (#68)
Sort the rule codes in the CLI help is more intuitive to the user. _Before_: ```text --select SELECT Comma-separated list of rules to enable or 'ALL' to enable all rules. Available rules: TOR001, TOR105, TOR402, TOR103, TOR401, TOR201, TOR004, TOR104, TOR501, TOR102, TOR202, TOR403, TOR203, TOR002, TOR901, TOR003, TOR101. Defaults to all except for TOR3, TOR4, TOR9. ``` _After_: ```text --select SELECT Comma-separated list of rules to enable or 'ALL' to enable all rules. Available rules: TOR001, TOR002, TOR003, TOR004, TOR101, TOR102, TOR103, TOR104, TOR105, TOR201, TOR202, TOR203, TOR401, TOR402, TOR403, TOR501, TOR901. Defaults to all except for TOR3, TOR4, TOR9. ``` Changes: - Sort rule codes - Minor code modernisation for readability (list comprehension, superfluous else)
1 parent a0d3b2e commit 8aef63b

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

tests/test_torchfix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_parse_error_code_str():
7777
("TOR102", {"TOR102"}),
7878
("TOR102,TOR101", {"TOR102", "TOR101"}),
7979
("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}),
80-
(None, GET_ALL_ERROR_CODES() - exclude_set),
80+
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),
8181
]
8282
for case, expected in cases:
8383
assert expected == process_error_code_str(case)

torchfix/torchfix.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def GET_ALL_ERROR_CODES():
4646
for cls in ALL_VISITOR_CLS:
4747
assert issubclass(cls, TorchVisitor)
4848
codes |= {error.error_code for error in cls.ERRORS}
49-
return codes
49+
return sorted(codes)
5050

5151

5252
@functools.cache
@@ -62,15 +62,12 @@ def expand_error_codes(codes):
6262
def construct_visitor(cls):
6363
if cls is TorchDeprecatedSymbolsVisitor:
6464
return cls(DEPRECATED_CONFIG_PATH)
65-
else:
66-
return cls()
65+
66+
return cls()
6767

6868

6969
def GET_ALL_VISITORS():
70-
out = []
71-
for v in ALL_VISITOR_CLS:
72-
out.append(construct_visitor(v))
73-
return out
70+
return [construct_visitor(v) for v in ALL_VISITOR_CLS]
7471

7572

7673
def get_visitors_with_error_codes(error_codes):
@@ -87,10 +84,7 @@ def get_visitors_with_error_codes(error_codes):
8784
break
8885
if not found:
8986
raise AssertionError(f"Unknown error code: {error_code}")
90-
out = []
91-
for cls in visitor_classes:
92-
out.append(construct_visitor(cls))
93-
return out
87+
return [construct_visitor(cls) for cls in visitor_classes]
9488

9589

9690
def process_error_code_str(code_str):
@@ -100,7 +94,7 @@ def process_error_code_str(code_str):
10094
# Default when --select is not provided.
10195
if code_str is None:
10296
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
103-
return GET_ALL_ERROR_CODES() - exclude_set
97+
return set(GET_ALL_ERROR_CODES()) - exclude_set
10498

10599
raw_codes = [s.strip() for s in code_str.split(",")]
106100

0 commit comments

Comments
 (0)