Commit a616229

Fix approx 150 type errors and enable mypy in core. (#508)
Excludes kernel for the moment, as it has quite a number of typing abnormalities. (I wasn't keeping a detailed count, but there were a non-trivial number of bugs -- not just typing adjustments.)
1 parent 24b1872 · commit a616229

30 files changed: +193 additions, -279 deletions

.github/workflows/test.yml

Lines changed: 7 additions & 2 deletions
@@ -57,7 +57,12 @@ jobs:
         run: |
           pytest -n 4 serving/
 
-      - name: MyPy Type Checking
+      - name: MyPy Type Checking Core
         if: ${{ !cancelled() }}
         run: |
-          mypy serving/
+          (cd core && mypy)
+
+      - name: MyPy Type Checking Serving
+        if: ${{ !cancelled() }}
+        run: |
+          (cd serving && mypy)

core/mypy.ini

Lines changed: 18 additions & 1 deletion
@@ -1,4 +1,21 @@
 [mypy]
 
+explicit_package_bases = True
 mypy_path = $MYPY_CONFIG_FILE_DIR
-packages = shark_turbine.aot,shark_turbine.dynamo,shark_turbine.support
+packages = shark_turbine
+
+# Missing typing stubs for iree.compiler.
+[mypy-iree.compiler.*]
+ignore_missing_imports = True
+
+# Missing typing stubs for iree.runtime.
+[mypy-iree.runtime.*]
+ignore_missing_imports = True
+
+# fx_importer needs to be fixed upstream.
+[mypy-shark_turbine.importers.fx_importer.*]
+ignore_errors = True
+
+# TODO: Fix all typing errors in TK.
+[mypy-shark_turbine.kernel.*]
+ignore_errors = True
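Note: the two `ignore_missing_imports` overrides exist because `iree.compiler` and `iree.runtime` ship no type stubs. A minimal sketch of the class of error they suppress (assuming an installed but unstubbed package):

    # Without the override, mypy reports something like:
    #   error: Cannot find implementation or library stub for module
    #   named "iree.compiler"
    from iree.compiler import ir  # resolves fine at runtime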

core/shark_turbine/aot/builtins/globals.py

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@
     GlobalAttributes,
 )
 
-from ..support.utils import (
+from torch.utils._pytree import (
     TreeSpec,
     tree_flatten,
     tree_map,
@@ -48,7 +48,7 @@ def __init__(
     ):
         if attrs is None:
             attrs = GlobalAttributes(
-                mutable=mutable,
+                mutable=bool(mutable),
                 external=external,
                 external_scope=external_scope,
                 name_mapper=name_mapper,
@@ -85,7 +85,7 @@ def __init__(
     ):
         if attrs is None:
             attrs = GlobalAttributes(
-                mutable=mutable,
+                mutable=bool(mutable),
                 external=external,
                 external_scope=external_scope,
                 name_mapper=name_mapper,
@@ -135,7 +135,7 @@ def __init__(
    ):
        if attrs is None:
            attrs = GlobalAttributes(
-                mutable=mutable,
+                mutable=bool(mutable),
                external=external,
                external_scope=external_scope,
                name_mapper=name_mapper,
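The `bool(mutable)` coercions follow a standard mypy pattern: the parameter arrives with a wider type (e.g. `Optional[bool]`) than the `bool` field it initializes, so an explicit `bool(...)` narrows it at the boundary. A minimal sketch with hypothetical names:

    from typing import Optional

    class Attrs:
        def __init__(self, mutable: bool):
            self.mutable = mutable

    def make_attrs(mutable: Optional[bool] = None) -> Attrs:
        # bool(None) is False, so the coercion both narrows the type
        # for mypy and yields a sane default when the flag is unset.
        return Attrs(mutable=bool(mutable))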

core/shark_turbine/aot/builtins/jittable.py

Lines changed: 10 additions & 11 deletions
@@ -21,6 +21,10 @@
     GraphModule,
 )
 from torch.fx.passes.shape_prop import TensorMetadata
+from torch.utils._pytree import (
+    tree_flatten,
+    tree_unflatten,
+)
 
 # TODO: Switch to upstream fx_importer vs local fork when ready.
 # from iree.compiler.extras.fx_importer import (
@@ -51,16 +55,12 @@
     util_d,
 )
 
+from ...support.logging import aot_logger as logger
+
 from ..passes import (
     functorch_functionalize,
 )
 
-from ..support.utils import (
-    logger,
-    tree_flatten,
-    tree_unflatten,
-)
-
 from ..support.ir_utils import (
     ModuleBuilder,
 )
@@ -153,10 +153,8 @@ def __init__(
         self,
         wrapped_f,
         *,
-        decompose_ops: Optional[List[torch._ops.OpOverload]] = None,
-        decomposition_table: Optional[
-            Dict[torch._ops.OpOverload, Callable[..., Any]]
-        ] = None,
+        decompose_ops: Optional[List[Any]] = None,
+        decomposition_table: Optional[Dict[Any, Callable[..., Any]]] = None,
         constraints: Optional[List[Constraint]] = None,
         function_name: Optional[str] = None,
         passes: Sequence[str] = DEFAULT_PASSES,
@@ -311,6 +309,7 @@ def flat_wrapped_f(*args):
         assert len(flat_ir_results) == len(result_tensor_infos)
         flat_py_results = []
         for ir_result, result_tensor_info in zip(flat_ir_results, result_tensor_infos):
+            assert result_tensor_info is not None
             (dtype,) = result_tensor_info
             native_ir_result = type_converter.materialize_torch_to_native(ir_result)
             if dtype is not None:
@@ -478,5 +477,5 @@ def _extract_graph_output_metadata(
             dtype = tensor_meta.dtype
         elif fake_val is not None:
             dtype = fake_val.dtype
-        output_metadata.append((dtype,))
+        output_metadata.append((dtype,) if dtype is not None else None)
     return out_spec, output_metadata
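The last two hunks form a pair: `_extract_graph_output_metadata` now records `None` (rather than `(None,)`) for outputs with no recoverable dtype, and the consumer asserts non-`None` before unpacking so mypy can prove the element is a real 1-tuple. A standalone sketch of the pattern, with hypothetical names:

    from typing import List, Optional, Tuple

    import torch

    def gather_dtypes(values: List[object]) -> List[Optional[Tuple[torch.dtype]]]:
        metadata: List[Optional[Tuple[torch.dtype]]] = []
        for v in values:
            dtype = v.dtype if isinstance(v, torch.Tensor) else None
            # Append None, not (None,), so the element type stays
            # Optional[Tuple[torch.dtype]] rather than Tuple[Optional[...]].
            metadata.append((dtype,) if dtype is not None else None)
        return metadata

    for info in gather_dtypes([torch.zeros(2), "not a tensor"]):
        if info is not None:
            (dtype,) = info  # safe: mypy knows info is a 1-tuple here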

core/shark_turbine/aot/compiled_module.py

Lines changed: 14 additions & 17 deletions
@@ -5,7 +5,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import enum
 import inspect
@@ -26,6 +26,7 @@
     PassManager,
     StringAttr,
 )
+from ..support.logging import aot_logger as logger
 from ..transforms.general.custom_op_expansion import ExpandCustomOpsPass
 
 from .support.procedural import (
@@ -38,9 +39,6 @@
     ModuleBuilder,
 )
 
-from .support.utils import (
-    logger,
-)
 
 __all__ = [
     "CompiledModule",
@@ -155,21 +153,21 @@ def export_procs(self) -> Generator[Tuple[str, ExportProcDef], None, None]:
         return filter(
             lambda kv_tuple: isinstance(kv_tuple[1], ExportProcDef),
             self.all_exports.items(),
-        )
+        )  # type: ignore
 
     @property
     def py_only_defs(self) -> Generator[Tuple[str, PyOnlyDef], None, None]:
         return filter(
             lambda kv_tuple: isinstance(kv_tuple[1], PyOnlyDef),
             self.all_exports.items(),
-        )
+        )  # type: ignore
 
     @property
     def globals_defs(self) -> Generator[Tuple[str, GlobalsDef], None, None]:
         return filter(
             lambda kv_tuple: isinstance(kv_tuple[1], GlobalsDef),
             self.all_exports.items(),
-        )
+        )  # type: ignore
 
     def def_attribute(self, key, value):
         # Some decorators, the only thing we do is convert them to PyOnlyDef.
@@ -209,11 +207,11 @@ def def_export_proc(self, name, f) -> ExportProcDef:
         file_line_loc = None
         try:
             sourcefile = inspect.getsourcefile(f)
-            _, linenums = sourcelines = inspect.getsourcelines(f)
-            if sourcefile and linenums:
-                file_line_loc = [sourcefile, linenums[0]]
-        except TypeError:
-            pass
+            _, linenum = sourcelines = inspect.getsourcelines(f)
+        except OSError:
+            ...
+        else:
+            file_line_loc = (sourcefile or "<unnamed>", linenum)
 
         sig = inspect.signature(f)
         if len(sig.parameters) < 1:
@@ -267,19 +265,18 @@ def __init__(
         self.module_builder = module_builder
         # The shadow dict holds instance attributes. We stash them here and the
         # Program instance itself arbitrates access via getattr/setattr.
-        self.shadow_dict = dict()
+        self.shadow_dict: dict[str, Any] = dict()
         self.current_import_phase = ImportPhase.TORCH_IR
 
 
 ################################################################################
 # Live reference accounting
 ################################################################################
 
-_all_compiled_module_class_infos: Dict[
+_all_compiled_module_class_infos: weakref.WeakKeyDictionary[
     "CompiledModuleMeta", CompiledModuleClassInfo
 ] = weakref.WeakKeyDictionary()
-
-_all_compiled_module_instance_infos: Dict[
+_all_compiled_module_instance_infos: weakref.WeakKeyDictionary[
     "CompiledModule", CompiledModuleInstanceInfo
 ] = weakref.WeakKeyDictionary()
 
@@ -292,7 +289,7 @@ def __init__(
     _metaclass_setup_complete = False
 
 
-    @property
+    @property  # type: ignore
     def _blackhole_instance_attribute(self):
         # We're not here.
         raise AttributeError
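The `def_export_proc` change fixes a real bug alongside the typing: the old code unpacked the starting line number into `linenums` and then indexed it (`linenums[0]`), which raised `TypeError` and was silently swallowed by the old handler, so the source location was never recorded. Meanwhile `inspect.getsourcelines` raises `OSError` when source cannot be retrieved, which `except TypeError` never caught. The `try`/`except`/`else` shape also keeps the success path out of the guarded region. A standalone sketch:

    import inspect
    from typing import Callable, Optional, Tuple

    def source_location(f: Callable) -> Optional[Tuple[str, int]]:
        file_line_loc: Optional[Tuple[str, int]] = None
        try:
            sourcefile = inspect.getsourcefile(f)
            # getsourcelines returns (lines, starting_line_number) and
            # raises OSError when the source cannot be retrieved.
            _, linenum = inspect.getsourcelines(f)
        except OSError:
            ...
        else:
            file_line_loc = (sourcefile or "<unnamed>", linenum)
        return file_line_loc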

core/shark_turbine/aot/exporter.py

Lines changed: 11 additions & 8 deletions
@@ -4,7 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Optional, Sequence, Union
+from typing import Any, Optional, Sequence, Union
 import functools
 import io
 from pathlib import Path
@@ -116,8 +116,8 @@ def compile(
             save_to = Path(save_to)
             output = Output.open_file(str(save_to))
         else:
-            assert isinstance(output, Output)
             output = save_to
+            assert isinstance(output, Output)
 
         target_backends = (
             target_backends
@@ -153,7 +153,7 @@ def compile(
 # Decorator which explicitly exports a function.
 # TODO: Make this a public API on CompiledModule.
 # See https://github.com/nod-ai/SHARK-Turbine/issues/126
-def export_proc(f=None, *, signature: Sequence[AbstractTypedef]) -> ExportProcDef:
+def export_proc(f=None, *, signature: Sequence[AbstractTypedef]) -> Any:
     if f is None:
         return functools.partial(export_proc, signature=signature)
     return ExportProcDef(f.__name__, f, signature=signature)
@@ -177,19 +177,22 @@ def export(mdl: ModuleLike, *example_args: torch.Tensor) -> ExportOutput:
         An ExportOutput object that wraps the compilation and provides
         easy access.
     """
+    TransformedModule: Any
     if isinstance(mdl, torch.nn.Module):
+        nn_module = mdl
         signature = [abstractify(t) for t in example_args]
 
-        class Exported(CompiledModule, export_name=mdl._get_name()):
-            params = export_parameters(mdl)
+        class Exported(CompiledModule, export_name=nn_module._get_name()):
+            params = export_parameters(nn_module)
 
             @export_proc(signature=signature)
             def main(self, *args):
-                return jittable(mdl.forward)(*args)
+                return jittable(nn_module.forward)(*args)
 
+        TransformedModule = Exported
     else:
         assert isinstance(mdl, CompiledModuleMeta)
-        Exported = mdl
+        TransformedModule = mdl
 
     session = Session()
     # There are some bugs with respect to Session/context interop that we
@@ -201,5 +204,5 @@ def main(self, *args):
     else:
         context = Context()
 
-    cm = Exported(context=context, import_to="import")
+    cm = TransformedModule(context=context, import_to="import")
     return ExportOutput(session, cm, importer_uses_session=importer_uses_session)
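The `nn_module = mdl` rebinding is presumably a narrowing aid: mypy narrows `mdl` after the `isinstance` check, but that narrowing does not carry into the nested class body and closure, where the name is still seen with its full `ModuleLike` union. Binding the narrowed value to a fresh name gives the inner scopes a variable typed `torch.nn.Module` outright. A reduced sketch with hypothetical names:

    from typing import Union

    import torch

    def module_name(mdl: Union[torch.nn.Module, type]) -> str:
        if isinstance(mdl, torch.nn.Module):
            nn_module = mdl  # fresh name, inferred as torch.nn.Module

            def describe() -> str:
                # nn_module type-checks cleanly here; mdl would be seen
                # with its full Union type inside this nested scope.
                return nn_module._get_name()

            return describe()
        return mdl.__name__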

core/shark_turbine/aot/passes/functorch.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Callable
+from typing import Any, Callable
 
 import torch
 from torch.fx import (
@@ -40,8 +40,8 @@
 # since that does not result in load bearing information loss. Note that
 # ONNX applies this post export, which suffers from the loss of output
 # destructuring rewrites that torch.export does.
-def functorch_functionalize(gm: GraphModule, *args) -> GraphModule:
-    functionalized_callable = _functionalize_callabale(gm)
+def functorch_functionalize(gm_callable: Any, *args) -> GraphModule:
+    functionalized_callable = _functionalize_callabale(gm_callable)
     # TODO: There is more of a dance needed if the user has entered with a fake_mode.
     with proxy_tensor.maybe_disable_fake_tensor_mode():
         new_gm = proxy_tensor.make_fx(

core/shark_turbine/aot/support/ir_utils.py

Lines changed: 5 additions & 9 deletions
@@ -15,10 +15,9 @@
 
 from ...importers.fx_importer import (
     ContextCache,
-)
-
-from ...importers.utils import (
-    RefTracker as FxRefTracker,
+    Empty,
+    EmptyType,
+    RefTracker,
 )
 
 from ...dynamo.type_conversion import (
@@ -58,10 +57,7 @@
     TORCH_DTYPE_TO_IREE_TYPE,
 )
 
-from .utils import (
-    RefTracker,
-    logger,
-)
+from ...support.logging import aot_logger as logger
 
 ###############################################################################
 # Configuration
@@ -150,7 +146,7 @@ def __init__(self, module_op: Operation):
         # Usually the FxImporter makes a new ref tracker for each invocation,
         # but we want to preserve it across individual JIT evaluations so
         # as to better intern tensors to attributes.
-        self.fx_py_attr_tracker = FxRefTracker()
+        self.fx_py_attr_tracker = RefTracker()
         self.native_type_converter = NativeTypeConverter(self.context)
 
     def handle_mlir_error(self, op: Operation, e: MLIRError, message: str):

core/shark_turbine/aot/support/procedural/base.py

Lines changed: 5 additions & 7 deletions
@@ -14,8 +14,10 @@
 )
 
 from contextlib import contextmanager
+import threading
 
 import torch
+from torch.utils._pytree import tree_map
 
 from ....support.ir_imports import (
     F32Type,
@@ -34,12 +36,8 @@
     ModuleBuilder,
 )
 
-from ..utils import (
-    thread_state,
-    tree_map,
-)
-
 ShapedTypeDynamicSizeSentinel = ShapedType.get_dynamic_size()
+_thread_state = threading.local()
 
 ###############################################################################
 # Tracing intrinsics
@@ -71,9 +69,9 @@ def handle_assignment(self, scope, target, updated_value):
 
 def _trace_scopes() -> List[IrTrace]:
     try:
-        trace_scopes = thread_state.trace_scopes
+        trace_scopes = _thread_state.trace_scopes
     except AttributeError:
-        trace_scopes = thread_state.trace_scopes = []
+        trace_scopes = _thread_state.trace_scopes = []
     return trace_scopes
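Replacing the shared `thread_state` helper with a module-local `threading.local()` keeps the lazy initialization in `_trace_scopes` lock-free: attributes set on a `threading.local` instance are visible only to the thread that set them. A small standalone demonstration under those assumptions:

    import threading

    _thread_state = threading.local()

    def scopes() -> list:
        try:
            return _thread_state.trace_scopes
        except AttributeError:
            # First access on this thread: create a per-thread list.
            _thread_state.trace_scopes = []
            return _thread_state.trace_scopes

    def worker() -> None:
        scopes().append(threading.current_thread().name)
        assert len(scopes()) == 1  # each thread sees only its own list

    threads = [threading.Thread(target=worker) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()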