Skip to content

Commit 28c3be2

Browse files
authored
Enforce ANN/PGH lints (#315)
One thing we lost in the pyre => pyright switch is enforcing that everything is typed (I had disabled ANN since the errors were rudundant with pyre).
1 parent 345fb5c commit 28c3be2

File tree

12 files changed

+56
-30
lines changed

12 files changed

+56
-30
lines changed

benchmarks/run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,9 @@ def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None:
263263

264264
# Create the benchmark method
265265
def helion_method(
266-
self: Any,
267-
*args: Any,
268-
) -> Callable[..., Any]:
266+
self: object,
267+
*args: object,
268+
) -> Callable[..., object]:
269269
"""Helion implementation."""
270270

271271
# Reset all Helion kernels before creating the benchmark function

examples/all_gather_matmul.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import os
4-
from typing import Any
54

65
import torch
76
import torch.distributed as dist
@@ -118,7 +117,7 @@ def helion_all_gather_matmul(
118117
b: torch.Tensor,
119118
a_out: torch.Tensor | None = None,
120119
progress: torch.Tensor | None = None,
121-
**kwargs: Any,
120+
**kwargs: int,
122121
) -> tuple[torch.Tensor, torch.Tensor]:
123122
configs = {
124123
"SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1),

examples/segment_reduction.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@ def segmented_reduction_helion(
4848

4949

5050
@triton.jit
51-
def combine_fn_triton(left_values, left_indices, right_values, right_indices):
51+
def combine_fn_triton(
52+
left_values: tl.tensor,
53+
left_indices: tl.tensor,
54+
right_values: tl.tensor,
55+
right_indices: tl.tensor,
56+
) -> tuple[tl.tensor, tl.tensor]:
5257
same_segment = left_indices == right_indices
5358
combined_values = tl.where(same_segment, left_values + right_values, right_values)
5459
combined_indices = right_indices
@@ -67,13 +72,13 @@ def combine_fn_triton(left_values, left_indices, right_values, right_indices):
6772
)
6873
@triton.jit
6974
def _segmented_reduction_triton(
70-
index, # the input index tensor
71-
in_ptr, # the input tensor
72-
out_ptr, # the output value tensor
75+
index: tl.tensor, # the input index tensor
76+
in_ptr: tl.tensor, # the input tensor
77+
out_ptr: tl.tensor, # the output value tensor
7378
E: tl.constexpr, # Number of elements in the input tensor (1d)
7479
C: tl.constexpr, # Number of features in the input tensor (2d)
7580
BLOCK_SIZE: tl.constexpr, # Block size for the scan
76-
):
81+
) -> None:
7782
# Triton version adapted from
7883
# https://github.com/fishmingyu/GeoT/blob/main/geot/triton/seg_reduction.py
7984
pid = tl.program_id(axis=0)
@@ -101,20 +106,24 @@ def _segmented_reduction_triton(
101106
tl.atomic_add(out_ptr + idxs * C + feature_id, result_values, mask & segment_start)
102107

103108

104-
def segmented_reduction_triton(indices, input_data, num_nodes):
109+
def segmented_reduction_triton(
110+
indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int
111+
) -> torch.Tensor:
105112
E, C = input_data.shape
106113
output = torch.zeros(
107114
(num_nodes, C), dtype=input_data.dtype, device=input_data.device
108115
)
109116

110-
def grid(META):
117+
def grid(META: dict[str, int]) -> tuple[int, ...]:
111118
return (triton.cdiv(E, META["BLOCK_SIZE"]) * C,)
112119

113120
_segmented_reduction_triton[grid](indices, input_data, output, E, C)
114121
return output
115122

116123

117-
def segmented_reduction_pytorch(indices, input_data, num_nodes):
124+
def segmented_reduction_pytorch(
125+
indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int
126+
) -> torch.Tensor:
118127
# Run PyTorch reference (scatter_add equivalent)
119128
num_features = input_data.size(1)
120129
pytorch_output = torch.zeros(
@@ -126,7 +135,7 @@ def segmented_reduction_pytorch(indices, input_data, num_nodes):
126135
return pytorch_output
127136

128137

129-
def main():
138+
def main() -> None:
130139
num_nodes = 100
131140
num_edges = 2000
132141
num_features = 128

helion/_compiler/ast_extension.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def visit(self, node: ast.AST) -> ast.AST:
234234
class _TupleParensRemovedUnparser(
235235
ast._Unparser # pyright: ignore[reportAttributeAccessIssue]
236236
):
237-
def visit_Tuple(self, node) -> None:
237+
def visit_Tuple(self, node: ast.Tuple) -> None:
238238
if _needs_to_remove_tuple_parens and isinstance(
239239
getattr(node, "ctx", None), ast.Store
240240
):

helion/_compiler/ast_read_writes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def visit_Expr(self, node: ast.Expr) -> ast.Expr | None:
232232
def dead_assignment_elimination(
233233
body: list[ast.AST],
234234
dce_vars: list[str],
235-
num_iterations=8,
235+
num_iterations: int = 8,
236236
input_rw: ReadWrites | None = None,
237237
) -> None:
238238
"""

helion/_compiler/inductor_lowering_extra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _register_inductor_lowering(
4646
)
4747

4848
@functools.wraps(decomp_fn) # pyright: ignore[reportArgumentType]
49-
def wrapped(*args: Any, **kwargs: Any) -> object:
49+
def wrapped(*args: object, **kwargs: object) -> object:
5050
args = list(args) # pyright: ignore[reportAssignmentType]
5151
kwargs = dict(kwargs)
5252
unpacked = False

helion/_compiler/traceback_compat.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def normalize(off: int) -> int:
153153
return None # fallback - no fancy anchors
154154

155155

156-
def format_frame_summary(frame_summary): # type: ignore[override]
156+
def format_frame_summary(frame_summary: traceback.FrameSummary) -> str: # type: ignore[override]
157157
"""Backport of Python 3.11's traceback.StackSummary.format_frame_summary()."""
158158

159159
_ensure_original_line(frame_summary)
@@ -170,18 +170,18 @@ def format_frame_summary(frame_summary): # type: ignore[override]
170170
stripped_line = frame_summary.line.strip()
171171
row.append(f" {stripped_line}\n")
172172

173-
line = frame_summary._original_line
173+
line = frame_summary._original_line # type: ignore[attr-defined]
174174
orig_line_len = len(line)
175175
frame_line_len = len(frame_summary.line.lstrip())
176176
stripped_characters = orig_line_len - frame_line_len
177177

178-
if frame_summary.colno is not None and frame_summary.end_colno is not None:
179-
start_offset = _byte_offset_to_character_offset(line, frame_summary.colno)
180-
end_offset = _byte_offset_to_character_offset(line, frame_summary.end_colno)
178+
if frame_summary.colno is not None and frame_summary.end_colno is not None: # type: ignore[attr-defined]
179+
start_offset = _byte_offset_to_character_offset(line, frame_summary.colno) # type: ignore[attr-defined]
180+
end_offset = _byte_offset_to_character_offset(line, frame_summary.end_colno) # type: ignore[attr-defined]
181181
code_segment = line[start_offset:end_offset]
182182

183183
anchors = None
184-
if frame_summary.lineno == frame_summary.end_lineno:
184+
if frame_summary.lineno == frame_summary.end_lineno: # type: ignore[attr-defined]
185185
with suppress(Exception):
186186
anchors = _extract_caret_anchors_from_line_segment(code_segment)
187187
else:

helion/_compiler/type_printer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def append(self, text: str) -> None:
5151
class ASTPrinter(_TupleParensRemovedUnparser):
5252
_indent: int
5353

54-
def __init__(self, *args, **kwargs) -> None:
54+
def __init__(self, *args: object, **kwargs: object) -> None:
5555
super().__init__(*args, **kwargs)
5656
assert self._source == []
5757
self._source = self.output = OutputLines(self)

helion/autotuner/base_search.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
import time
1515
from typing import TYPE_CHECKING
1616
from typing import NamedTuple
17+
from typing import NoReturn
18+
19+
if TYPE_CHECKING:
20+
from triton.runtime.jit import JITFunction
1721

1822
from torch._inductor.runtime.triton_compat import OutOfResources
1923
from torch._inductor.runtime.triton_compat import PTXASError
@@ -152,7 +156,7 @@ def extract_launcher(
152156
grid: tuple[int, ...],
153157
*args: object,
154158
**kwargs: object,
155-
):
159+
) -> NoReturn:
156160
"""Custom launcher that extracts arguments instead of executing."""
157161
raise _ExtractedLaunchArgs(triton_kernel, grid, args, kwargs)
158162

@@ -524,7 +528,18 @@ def _mark_complete(self) -> bool:
524528
class _ExtractedLaunchArgs(Exception):
525529
"""Exception that carries kernel launch arguments for precompiler extraction."""
526530

527-
def __init__(self, triton_kernel, grid, args, kwargs):
531+
kernel: JITFunction[object]
532+
grid: object
533+
args: tuple[object, ...]
534+
kwargs: dict[str, object]
535+
536+
def __init__(
537+
self,
538+
triton_kernel: JITFunction[object],
539+
grid: object,
540+
args: tuple[object, ...],
541+
kwargs: dict[str, object],
542+
) -> None:
528543
super().__init__()
529544
self.kernel = triton_kernel
530545
self.grid = grid

helion/runtime/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def default_launcher(
5050
*args: object,
5151
num_warps: int,
5252
num_stages: int,
53-
):
53+
) -> object:
5454
"""Default launcher function that executes the kernel immediately."""
5555
return triton_kernel.run(
5656
*args, grid=grid, warmup=False, num_warps=num_warps, num_stages=num_stages

0 commit comments

Comments
 (0)