Commit 457c5b1

Lint ao (#1521)
1 parent cc8e80b commit 457c5b1

33 files changed: +1108 -668 lines changed
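
Every hunk below is mechanical: imports re-sorted and regrouped, single quotes normalized to double quotes, long calls exploded one argument per line with trailing commas, and unused names dropped. That pattern is consistent with an isort-style import sorter plus a black-compatible formatter; assuming ruff (the commit message doesn't name the tool), the equivalent local invocation would be `ruff check --fix .` followed by `ruff format .`.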

.github/scripts/github_utils.py

Lines changed: 1 addition & 3 deletions

@@ -3,14 +3,12 @@
 import json
 import os
 import warnings
-
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 from urllib.error import HTTPError
 from urllib.parse import quote
 from urllib.request import Request, urlopen
 
-
 GITHUB_API_URL = "https://api.github.com"
 
 
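A recurring change in this commit: `cast` moves to the end of every `typing` import, `TYPE_CHECKING` moves to the front, and `STDOUT` jumps ahead of `CalledProcessError`. That matches isort's default order_by_type behavior (an assumption, since the tooling isn't named): within one `from ... import ...`, names are ordered CONSTANT_CASE first, then CamelCase, then lowercase, alphabetically within each group. A minimal sketch of that keying:

    # Rank names the way isort's order_by_type does: constants first,
    # then classes, then everything else; alphabetical within each rank.
    def isort_key(name):
        if name.isupper():          # CONSTANT_CASE, e.g. STDOUT
            rank = 0
        elif name[:1].isupper():    # CamelCase, e.g. CalledProcessError
            rank = 1
        else:                       # lowercase, e.g. check_output
            rank = 2
        return (rank, name)

    names = ["check_output", "STDOUT", "CalledProcessError", "cast", "Union"]
    print(sorted(names, key=isort_key))
    # ['STDOUT', 'CalledProcessError', 'Union', 'cast', 'check_output']
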
.github/scripts/gitutils.py

Lines changed: 2 additions & 2 deletions

@@ -9,14 +9,14 @@
 from typing import (
     Any,
     Callable,
-    cast,
     Dict,
     Iterator,
     List,
     Optional,
     Tuple,
     TypeVar,
     Union,
+    cast,
 )
 
 T = TypeVar("T")
@@ -45,7 +45,7 @@ def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
 
 
 def _check_output(items: List[str], encoding: str = "utf-8") -> str:
-    from subprocess import CalledProcessError, check_output, STDOUT
+    from subprocess import STDOUT, CalledProcessError, check_output
 
     try:
         return check_output(items, stderr=STDOUT).decode(encoding)
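
Behaviorally nothing changes here; `_check_output` still folds stderr into the captured stdout via `stderr=STDOUT`. A runnable sketch of that pattern (assumes `git` is on the PATH):

    from subprocess import STDOUT, CalledProcessError, check_output

    try:
        # stderr=STDOUT merges both streams, so error text is captured too.
        out = check_output(["git", "--version"], stderr=STDOUT).decode("utf-8")
        print(out.strip())
    except CalledProcessError as err:
        print("command failed:", err.output.decode("utf-8"))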

.github/scripts/label_utils.py

Lines changed: 2 additions & 3 deletions

@@ -1,11 +1,10 @@
 """GitHub Label Utilities."""
 
 import json
-
 from functools import lru_cache
-from typing import Any, List, Tuple, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Any, List, Tuple, Union
 
-from github_utils import gh_fetch_url_and_headers, GitHubComment
+from github_utils import GitHubComment, gh_fetch_url_and_headers
 
 # TODO: this is a temp workaround to avoid circular dependencies,
 # and should be removed once GitHubPR is refactored out of trymerge script.
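
`TYPE_CHECKING` is the standard escape hatch for exactly the circular-dependency problem that TODO describes: the guarded import is evaluated only by static type checkers, never at runtime. A generic sketch (the module, class, and method names are hypothetical):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by mypy/pyright only; skipped at runtime, so a module that
        # itself imports this one cannot create an import cycle.
        from trymerge import GitHubPR  # hypothetical

    def count_comments(pr: "GitHubPR") -> int:
        return len(pr.get_comments())  # hypothetical method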

.github/scripts/trymerge.py

Lines changed: 6 additions & 10 deletions

@@ -23,44 +23,41 @@
 from typing import (
     Any,
     Callable,
-    cast,
     Dict,
     Iterable,
     List,
     NamedTuple,
     Optional,
     Pattern,
     Tuple,
+    cast,
 )
 from warnings import warn
 
 import yaml
 from github_utils import (
+    GitHubComment,
     gh_fetch_json_list,
     gh_fetch_merge_base,
     gh_fetch_url,
     gh_graphql,
     gh_post_commit_comment,
     gh_post_pr_comment,
     gh_update_pr_state,
-    GitHubComment,
 )
-
 from gitutils import (
+    GitRepo,
     are_ghstack_branches_in_sync,
     get_git_remote_name,
     get_git_repo_dir,
-    GitRepo,
     patterns_to_regex,
     retries_decorator,
 )
 from label_utils import (
     gh_add_labels,
     gh_remove_label,
-    has_required_labels,
-    LABEL_ERR_MSG,
 )
-from trymerge_explainer import get_revert_message, TryMergeExplainer
+from trymerge_explainer import TryMergeExplainer, get_revert_message
 
 # labels
 MERGE_IN_PROGRESS_LABEL = "merging"
@@ -1477,7 +1474,7 @@ def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
 
 
 def checks_to_markdown_bullets(
-    checks: List[Tuple[str, Optional[str], Optional[int]]]
+    checks: List[Tuple[str, Optional[str], Optional[int]]],
 ) -> List[str]:
     return [
         f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
@@ -1716,7 +1713,7 @@ def get_readable_drci_results(drci_classifications: Any) -> str:
     try:
         print(f"From Dr.CI checkrun summary: {drci_summary}")
         drci_classifications = json.loads(str(drci_summary))
-    except json.JSONDecodeError as error:
+    except json.JSONDecodeError:
        warn("Invalid Dr.CI checkrun summary")
         drci_classifications = {}
 
@@ -1887,7 +1884,6 @@ def do_revert_prs(
     dry_run: bool = False,
 ) -> None:
     # Prepare and push revert commits
-    commit_shas: List[str] = []
     for commit_sha, pr in shas_and_prs:
         revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
         revert_msg += extra_msg
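
Two hunks here are dead-code removal rather than formatting: the `as error` binding and the `commit_shas` accumulator were never read (ruff reports this as F841, an assumption about the tool). The exception case in miniature:

    import json
    from warnings import warn

    try:
        data = json.loads("not valid json")
    except json.JSONDecodeError:  # no "as e": the name would never be read
        warn("Invalid Dr.CI checkrun summary")
        data = {}
    print(data)  # {}

The comma added after the last parameter of `checks_to_markdown_bullets` is black's "magic trailing comma", which tells the formatter to keep the signature exploded one argument per line.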

.github/scripts/trymerge_explainer.py

Lines changed: 0 additions & 1 deletion

@@ -2,7 +2,6 @@
 import re
 from typing import List, Optional, Pattern, Tuple
 
-
 BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
 
 CIFLOW_LABEL = re.compile(r"^ciflow/.+")

docs/source/conf.py

Lines changed: 3 additions & 6 deletions

@@ -22,8 +22,9 @@
 
 import os
 import sys
-from docutils.parsers import rst
+
 import pytorch_sphinx_theme
+from docutils.parsers import rst
 
 sys.path.append(os.path.abspath("."))
 
@@ -60,7 +61,7 @@
 
 ### TODO: Delete this when we have content
 suppress_warnings = [
-    'toc.unlisted',
+    "toc.unlisted",
 ]
 ###
 
@@ -169,12 +170,8 @@
 # -- A patch that prevents Sphinx from cross-referencing ivar tags -------
 # See http://stackoverflow.com/a/41184353/3343043
 
-from docutils import nodes
-from sphinx import addnodes
-from sphinx.util.docfields import TypedField
 
 from custom_directives import CustomCardEnd, CustomCardItem, CustomCardStart
-from docutils.parsers import rst
 
 rst.directives.register_directive("customcardstart", CustomCardStart)
 rst.directives.register_directive("customcarditem", CustomCardItem)
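
The deleted imports were evidently unused, and the duplicate `from docutils.parsers import rst` already appears near the top of the file after the first hunk. The registrations that remain use the stock docutils extension point. A self-contained sketch with a hypothetical directive (`NoteBox` is invented for illustration):

    from docutils import nodes
    from docutils.parsers import rst

    class NoteBox(rst.Directive):
        has_content = True

        def run(self):
            # Render the directive body as a literal block node.
            text = "\n".join(self.content)
            return [nodes.literal_block(text, text)]

    rst.directives.register_directive("notebox", NoteBox)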

docs/source/tutorials_source/template_tutorial.py

Lines changed: 3 additions & 2 deletions

@@ -32,8 +32,9 @@
 # -----
 #
 # Example code (the output below is generated automatically):
-# 
+#
 import torch
+
 x = torch.rand(5, 3)
 print(x)
 
@@ -48,7 +49,7 @@
 ######################################################################
 # Conclusion
 # ----------
-# 
+#
 # Summarize the steps and concepts covered. Highlight key takeaways.
 #
 # Further Reading
Lines changed: 45 additions & 28 deletions

@@ -1,57 +1,68 @@
+import cv2
+import matplotlib.pyplot as plt
 import numpy as np
 import torch
-import matplotlib.pyplot as plt
-import cv2
-import torch.utils.benchmark as benchmark
-
 from torch._inductor import config as inductorconfig
+
 inductorconfig.triton.unique_kernel_names = True
 inductorconfig.coordinate_descent_tuning = True
 inductorconfig.coordinate_descent_check_all_directions = True
 
+
 def profiler_runner(path, fn, *args, **kwargs):
     with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA],
-        record_shapes=True) as prof:
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+            torch.profiler.ProfilerActivity.CUDA,
+        ],
+        record_shapes=True,
+    ) as prof:
         result = fn(*args, **kwargs)
     print(f"Saving trace under {path}")
     prof.export_chrome_trace(path)
     return result
 
+
 def show_anns(anns):
     if len(anns) == 0:
         return
-    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+    sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
     ax = plt.gca()
     ax.set_autoscale_on(False)
 
-    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
-    img[:,:,3] = 0
+    img = np.ones(
+        (
+            sorted_anns[0]["segmentation"].shape[0],
+            sorted_anns[0]["segmentation"].shape[1],
+            4,
+        )
+    )
+    img[:, :, 3] = 0
     ms = []
     for ann in sorted_anns:
-        m = ann['segmentation']
+        m = ann["segmentation"]
         ms.append(torch.as_tensor(m))
         color_mask = np.concatenate([np.random.random(3), [0.35]])
         img[m] = color_mask
     ax.imshow(img)
     return torch.stack(ms)
 
-image = cv2.imread('dog.jpg')
+
+image = cv2.imread("dog.jpg")
 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
 
 # from segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamAutomaticMaskGenerator
-# 
+#
 # sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
 # model_type = "vit_h"
 device = "cuda"
-# 
+#
 # sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
 # sam.to(device=device)
 
-from sam2.build_sam import build_sam2
 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+from sam2.build_sam import build_sam2
 
 sam2_checkpoint = "checkpoints/sam2_hiera_large.pt"
 model_cfg = "sam2_hiera_l.yaml"
@@ -66,7 +77,7 @@ def show_anns(anns):
 ## TODO: Implement mIoU to allow approximations.
 # torch.set_float32_matmul_precision('high')
 # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
-## 
+##
 
 ## TODO: Using CUDA graphs can cause numerical differences?
 mask_generator.predictor.model.image_encoder = torch.compile(
@@ -93,24 +104,26 @@ def show_anns(anns):
 )
 
 # with torch.backends.cuda.sdp_kernel(enable_cudnn=False): #, enable_math=False, enable_mem_efficient=False):
-with torch.backends.cuda.sdp_kernel(enable_cudnn=True): #, enable_math=False, enable_mem_efficient=False):
+with torch.backends.cuda.sdp_kernel(
+    enable_cudnn=True
+):  # , enable_math=False, enable_mem_efficient=False):
     # Run thrice for warmup
     masks = mask_generator.generate(image)
     masks = mask_generator.generate(image)
     masks = mask_generator.generate(image)
-    
+
     # Save an example
-    plt.figure(figsize=(image.shape[1]/100., image.shape[0]/100.), dpi=100)
+    plt.figure(figsize=(image.shape[1] / 100.0, image.shape[0] / 100.0), dpi=100)
     plt.imshow(image)
     ms = show_anns(masks)
     ms_ref = torch.load("dog_mask_fast.pt")
     torch.testing.assert_allclose(ms, ms_ref)
     print("Masks match reference")
     # # torch.save(ms, "dog_mask_fast.pt")
-    plt.axis('off')
+    plt.axis("off")
     plt.tight_layout()
-    plt.savefig('dog_mask_fast.png', format='png')
-    
+    plt.savefig("dog_mask_fast.png", format="png")
+
 # Benchmark
 torch.cuda.synchronize()
 start_event = torch.cuda.Event(enable_timing=True)
@@ -120,14 +133,18 @@ def show_anns(anns):
     masks = mask_generator.generate(image)
 end_event.record()
 torch.cuda.synchronize()
-print(start_event.elapsed_time(end_event) / 10.)
-
+print(start_event.elapsed_time(end_event) / 10.0)
+
 # Save a GPU trace
-profiler_runner(f"amg_example_trace.json.gz", mask_generator.generate, image)
-
+profiler_runner("amg_example_trace.json.gz", mask_generator.generate, image)
+
 # Write out memory usage
 max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
 _, total_memory = torch.cuda.mem_get_info()
-max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
+max_memory_allocated_percentage = int(
+    100 * (max_memory_allocated_bytes / total_memory)
+)
 max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
-print(f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}")
+print(
+    f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}"
+)
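
Past the quote and line-length normalization (plus dropping the unused `torch.utils.benchmark` import and the unused `f` prefix on the trace path), the benchmark at the bottom of this file is the standard CUDA-event timing pattern. A self-contained sketch (assumes a CUDA device; `fn` stands in for `mask_generator.generate`):

    import torch

    def time_cuda_ms(fn, iters=10):
        torch.cuda.synchronize()  # drain pending GPU work before timing
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()  # wait until `end` has actually been reached
        return start.elapsed_time(end) / iters  # milliseconds per call

Usage would look like `time_cuda_ms(lambda: mask_generator.generate(image))`; recording events on the stream rather than timing on the CPU avoids counting host-side launch overhead as kernel time.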
