Skip to content

Commit a44aabb

Browse files
authored
Swap to using pyright (#259)
1 parent 9f914ea commit a44aabb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+440
-443
lines changed

.github/workflows/lint.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,6 @@ jobs:
4747
- name: Install lint dependencies
4848
run: ./lint.sh install
4949

50-
- name: Run lint checks
51-
run: ./lint.sh check
52-
5350
- name: Install pre-commit
5451
run: |
5552
python -m pip install --upgrade pip

.pre-commit-config.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ repos:
3131
- repo: local
3232
hooks:
3333
- id: check-main-in-examples
34-
name: Ensure all example scripts have main()
34+
name: ensure all example scripts have main()
3535
entry: python scripts/lint_examples_main.py
3636
language: system
3737
files: ^examples/.*\.py$
3838

39-
- repo: https://github.com/facebook/pyre-check
40-
rev: 24abb8681e9b5a93700cfd2a404de2506d0f9689
41-
hooks:
42-
- id: pyre-check-no-python
39+
- repo: https://github.com/RobertCraigie/pyright-python
40+
rev: v1.1.403
41+
hooks:
42+
- id: pyright
43+
language: system

.pyre_configuration

Lines changed: 0 additions & 17 deletions
This file was deleted.

benchmarks/run.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def check_and_setup_tritonbench() -> None:
6363
"""Check if tritonbench is installed and install it from GitHub if not."""
6464
# Check if tritonbench is already installed
6565
try:
66-
import tritonbench
66+
import tritonbench # pyright: ignore[reportMissingImports]
6767

6868
return # Already installed
6969
except ImportError:
@@ -126,7 +126,7 @@ def check_and_setup_tritonbench() -> None:
126126

127127
# Verify installation worked
128128
try:
129-
import tritonbench # noqa: F401
129+
import tritonbench # noqa: F401 # pyright: ignore[reportMissingImports]
130130

131131
print(
132132
f"Tritonbench installed successfully with {install_flag}.",
@@ -196,7 +196,9 @@ def main() -> None:
196196

197197
# Import tritonbench components
198198
try:
199-
from tritonbench.utils.parser import get_parser # pyre-ignore[21]
199+
from tritonbench.utils.parser import ( # pyright: ignore[reportMissingImports]
200+
get_parser,
201+
)
200202
except ImportError:
201203
print(
202204
"Error: Could not import tritonbench. Make sure it's in the path.",
@@ -223,16 +225,16 @@ def main() -> None:
223225
tb_args = tb_parser.parse_args(tritonbench_args)
224226

225227
# Register the Helion kernel with tritonbench BEFORE importing the operator
226-
from tritonbench.utils.triton_op import ( # pyre-ignore[21]
228+
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
227229
register_benchmark_mannually,
228230
)
229231

230232
# Create the benchmark method
231-
def create_helion_method( # pyre-ignore[3]
232-
kernel_func: Callable[..., Any], # pyre-ignore[2]
233+
def create_helion_method(
234+
kernel_func: Callable[..., Any],
233235
) -> Callable[..., Any]:
234-
def helion_method( # pyre-ignore[3]
235-
self: Any, # pyre-ignore[2]
236+
def helion_method(
237+
self: Any,
236238
*args: Any,
237239
) -> Callable[..., Any]:
238240
"""Helion implementation."""
@@ -246,7 +248,7 @@ def helion_method( # pyre-ignore[3]
246248
if isinstance(attr, Kernel):
247249
attr.reset()
248250

249-
def _inner() -> Callable[..., Any]: # pyre-ignore[3]
251+
def _inner() -> Callable[..., Any]:
250252
return kernel_func(*args)
251253

252254
return _inner

examples/all_gather_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None:
181181
dist_group = dist.group.WORLD
182182
if dist_group is None:
183183
raise RuntimeError("No distributed group available")
184-
ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(
184+
ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul( # pyright: ignore[reportCallIssue]
185185
golden_a, [b], gather_dim=0, group_name=dist_group.group_name
186186
)
187187
torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1)

examples/attention.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,9 @@ def attention(
6666
return out.view(q_in.size())
6767

6868

69-
attention_dynamic: object = helion.kernel(
69+
attention_dynamic: object = helion.kernel( # pyright: ignore[reportCallIssue]
7070
attention.fn,
71-
# pyre-fixme[6]
72-
configs=attention.configs,
71+
configs=attention.configs, # pyright: ignore[reportArgumentType]
7372
static_shapes=False,
7473
)
7574

helion/_compiler/ast_extension.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ class ExtendedAST:
3939
subclassing each AST node class and mixing in this one.
4040
"""
4141

42-
# pyre-ignore[13]
4342
_fields: tuple[str, ...]
4443

4544
def __init__(
@@ -132,8 +131,7 @@ class Wrapper(ExtendedAST, cls):
132131

133132

134133
def create(cls: type[_T], **fields: object) -> _T:
135-
# pyre-ignore[28]
136-
result = get_wrapper_cls(cls)(**fields, _location=current_location())
134+
result = get_wrapper_cls(cls)(**fields, _location=current_location()) # pyright: ignore[reportCallIssue]
137135
assert isinstance(result, ExtendedAST)
138136
result._location.to_ast(result)
139137
return typing.cast("_T", result)
@@ -165,16 +163,16 @@ def statement_from_string(template: str, **placeholders: ast.AST) -> ast.stmt:
165163

166164
def _replace(node: _R) -> _R:
167165
if isinstance(node, list):
168-
return [_replace(item) for item in node]
166+
return [_replace(item) for item in node] # pyright: ignore[reportReturnType]
169167
if not isinstance(node, ast.AST):
170168
return node
171169
if isinstance(node, ast.Name) and node.id in placeholders:
172-
return placeholders[node.id]
170+
return placeholders[node.id] # pyright: ignore[reportReturnType]
173171
cls = get_wrapper_cls(type(node))
174-
return location.to_ast(
172+
return location.to_ast( # pyright: ignore[reportReturnType]
175173
cls(
176174
**{field: _replace(getattr(node, field)) for field in node._fields},
177-
_location=location,
175+
_location=location, # pyright: ignore[reportCallIssue]
178176
)
179177
)
180178

@@ -196,11 +194,10 @@ def convert(node: ast.AST) -> ast.AST:
196194
# some nodes like arguments lack location information
197195
location = current_location()
198196
with location:
199-
# pyre-ignore[28]
200197
return cls(
201198
**{field: convert(getattr(node, field)) for field in node._fields},
202199
**{attr: getattr(node, attr) for attr in node._attributes},
203-
_location=location,
200+
_location=location, # pyright: ignore[reportCallIssue]
204201
)
205202
elif isinstance(node, list):
206203
return [convert(item) for item in node]
@@ -234,23 +231,23 @@ def visit(self, node: ast.AST) -> ast.AST:
234231
)
235232

236233

237-
class _TupleParensRemovedUnparser(ast._Unparser): # pyre-ignore[11]
238-
def visit_Tuple(self, node) -> None: # pyre-ignore[2]
234+
class _TupleParensRemovedUnparser(
235+
ast._Unparser # pyright: ignore[reportAttributeAccessIssue]
236+
):
237+
def visit_Tuple(self, node) -> None:
239238
if _needs_to_remove_tuple_parens and isinstance(
240239
getattr(node, "ctx", None), ast.Store
241240
):
242241
if len(node.elts) == 1: # single-element tuple
243-
self.traverse(node.elts[0]) # pyre-ignore[16]
244-
self.write(",") # pyre-ignore[16]
242+
self.traverse(node.elts[0])
243+
self.write(",")
245244
else: # multi-element tuple
246-
self.interleave( # pyre-ignore[16]
247-
lambda: self.write(", "), self.traverse, node.elts
248-
)
245+
self.interleave(lambda: self.write(", "), self.traverse, node.elts)
249246
return
250247
# For everything else fall back to default behavior
251-
super().visit_Tuple(node) # pyre-ignore[16]
248+
super().visit_Tuple(node)
252249

253250

254251
def unparse(ast_obj: ast.AST) -> str:
255252
unparser = _TupleParensRemovedUnparser()
256-
return unparser.visit(ast_obj) # pyre-ignore[16]
253+
return unparser.visit(ast_obj)

helion/_compiler/ast_read_writes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class ReadWrites(typing.NamedTuple):
4343
reads: dict[str, int]
4444
writes: dict[str, int]
4545

46-
def __iter__(self) -> typing.Iterator[str]:
46+
def __iter__(self) -> typing.Iterator[str]: # pyright: ignore[reportIncompatibleMethodOverride]
4747
return iter({**self.reads, **self.writes})
4848

4949
@staticmethod

helion/_compiler/compile_environment.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
161161
# )
162162
# TODO(jansel): I was hoping the above would work, seems like some decomps require concrete values
163163
# to determine zeroness. Figure out a better way to do this.
164-
# pyre-ignore[29]
164+
165165
self.shape_env.var_to_val[sym._sympy_()] = sympy.Integer(hint)
166166
assert isinstance(sym._sympy_(), sympy.Symbol)
167167
self.debug_shape_renames[sym._sympy_()] = sympy.Symbol(debug_name, integer=True)
@@ -191,8 +191,8 @@ def cached_create_unbacked_symint(
191191
Returns:
192192
A consistent unbacked symint for the given key
193193
"""
194-
# pyre-ignore[16]
195-
key = tuple([x._sympy_() if hasattr(x, "_sympy_") else x for x in key])
194+
195+
key = tuple([x._sympy_() if hasattr(x, "_sympy_") else x for x in key]) # pyright: ignore[reportAttributeAccessIssue]
196196
result = self._symint_cache.get(key)
197197
if result is None:
198198
result = self.create_unbacked_symint(hint)
@@ -237,9 +237,9 @@ def to_fake(self, obj: object, origin: Origin) -> object:
237237
return [self.to_fake(e, origin) for e in obj]
238238
if isinstance(obj, tuple) and hasattr(obj, "_fields"):
239239
return type(obj)(
240-
**{ # pyre-ignore[6]
240+
**{
241241
k: self.to_fake(e, origin)
242-
for k, e in obj._asdict().items() # pyre-ignore[16]
242+
for k, e in obj._asdict().items() # pyright: ignore[reportAttributeAccessIssue]
243243
}
244244
)
245245
if isinstance(obj, tuple):
@@ -248,10 +248,10 @@ def to_fake(self, obj: object, origin: Origin) -> object:
248248
return {k: self.to_fake(e, origin) for k, e in obj.items()}
249249
if dataclasses.is_dataclass(obj):
250250
return dataclasses.replace(
251-
obj,
251+
obj, # pyright: ignore[reportArgumentType]
252252
**{
253253
k: self.to_fake(getattr(obj, k), origin)
254-
for k in obj.__dataclass_fields__ # pyre-ignore[16]
254+
for k in obj.__dataclass_fields__
255255
},
256256
)
257257

@@ -289,8 +289,8 @@ def size_hint(self, n: int | torch.SymInt) -> int:
289289
# If the size is a symbolic expression with unbacked symbols, then the shape environment
290290
# hint will be wrong since we assign a default value to unbacked symbols. Return a default hint.
291291
return 8192
292-
# pyre-ignore[6]
293-
return int(self.shape_env.size_hint(n._sympy_()))
292+
293+
return int(self.shape_env.size_hint(n._sympy_())) # pyright: ignore[reportArgumentType]
294294
assert isinstance(n, int)
295295
return n
296296

@@ -514,5 +514,4 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:
514514

515515

516516
def _has_unbacked(expr: sympy.Expr) -> bool:
517-
# pyre-ignore[16]
518-
return any(n.name.startswith("u") for n in expr.free_symbols)
517+
return any(n.name.startswith("u") for n in expr.free_symbols) # pyright: ignore[reportAttributeAccessIssue]

helion/_compiler/device_function.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str:
221221
if expr in expr_to_origin:
222222
return self._lift_sympy_arg(expr)
223223
replacements = {}
224-
for sym in sorted(expr.free_symbols, key=lambda x: x.name):
224+
for sym in sorted(expr.free_symbols, key=lambda x: x.name): # pyright: ignore[reportAttributeAccessIssue]
225225
assert isinstance(sym, sympy.Symbol)
226226
if sym in self.expr_to_var_info:
227227
replacements[sym] = sympy.Symbol(
@@ -254,7 +254,7 @@ def _lift_sympy_arg(self, expr: sympy.Expr) -> str:
254254
def user_sympy_expr(self, expr: sympy.Expr) -> str:
255255
"""A sympy expression that flows into user computations."""
256256
replacements = {}
257-
for sym in sorted(expr.free_symbols, key=lambda s: s.name):
257+
for sym in sorted(expr.free_symbols, key=lambda s: s.name): # pyright: ignore[reportAttributeAccessIssue]
258258
assert isinstance(sym, sympy.Symbol)
259259
block_idx = CompileEnvironment.current().get_block_id(sym)
260260
if block_idx is not None:
@@ -474,7 +474,9 @@ def dead_code_elimination(self) -> None:
474474

475475
# drop any unused args
476476
args_to_remove = {
477-
arg.name for arg in self.arguments if arg.name not in rw.reads
477+
arg.name
478+
for arg in self.arguments
479+
if arg.name not in rw.reads # pyright: ignore[reportPossiblyUnboundVariable]
478480
}
479481
if args_to_remove:
480482
self.arguments = [

0 commit comments

Comments
 (0)