Skip to content

Commit c1dc94c

Browse files
[tk] Switch symbolic/shape placeholders to sympy (#375)
By using sympy instead of the stand-in symbolic support, we get full support for index (partial) evaluation, dynamic dimensions and validation checks.
1 parent d6821d3 commit c1dc94c

File tree

13 files changed

+951
-716
lines changed

13 files changed

+951
-716
lines changed

python/shark_turbine/kernel/_support/indexing.py

Lines changed: 345 additions & 228 deletions
Large diffs are not rendered by default.

python/shark_turbine/kernel/_support/tracing.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
import torch.fx as fx
2121

2222
from .indexing import (
23-
BoundedSymbolicValue,
23+
backed_sym_index_type,
24+
BoundedRelation,
25+
IndexExpr,
2426
Grid,
2527
KernelBuffer,
26-
sym_0,
28+
SymIndex,
2729
)
2830

2931
from ..lang.types import (
@@ -98,10 +100,17 @@ class KernelTracer(SubgraphTracer):
98100
# Register our custom proxies.
99101
def proxy(self, node: fx.Node) -> fx.Proxy:
100102
t = node.type
101-
if t is not None and issubclass(t, KernelBuffer):
102-
return KernelBufferProxy(node, self, t)
103+
if t is not None:
104+
if issubclass(t, KernelBuffer):
105+
return KernelBufferProxy(node, self, t)
103106
return super().proxy(node)
104107

108+
def create_arg(self, a):
109+
# Let IndexExpr persist as arguments.
110+
if isinstance(a, IndexExpr):
111+
return a
112+
return super().create_arg(a)
113+
105114

106115
class CapturedTrace:
107116
def __init__(self, region_graph: RegionGraph, root_graph: str):
@@ -163,23 +172,28 @@ def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]):
163172
super().__init__(eager=False)
164173
self.region_graph = region_graph
165174
self.grid_type = grid_type
175+
self.current_thread_types = [
176+
backed_sym_index_type(BoundedRelation(0, n, upper_inclusive=False))
177+
for n in grid_type.symbolic_shape
178+
]
166179

167180
### ========================================================================
168181
### Core Operations
169182
### ========================================================================
170183

171184
def handle_thread_program_id(self, op, axis: int) -> Index:
172-
grid_shape = self.grid_type.symbolic_shape
173-
if axis < 0 or axis >= len(grid_shape):
185+
grid_types = self.current_thread_types
186+
if axis < 0 or axis >= len(grid_types):
174187
raise IndexError(
175-
f"Illegal index into grid of rank {len(grid_shape)}: {axis}"
188+
f"Illegal index into grid of rank {len(grid_types)}: {axis}"
176189
)
190+
177191
proxy = self.region_graph.create_proxy(
178192
"call_function",
179193
op,
180194
args=(axis,),
181195
kwargs={},
182-
type_expr=BoundedSymbolicValue.bound(sym_0, grid_shape[axis]),
196+
type_expr=grid_types[axis],
183197
)
184198
return proxy
185199

python/shark_turbine/kernel/compiler/analysis.py

Lines changed: 0 additions & 252 deletions
This file was deleted.

python/shark_turbine/kernel/compiler/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
NDEBUG = False
2+
3+
14
class CodegenError(Exception):
25
...
36

0 commit comments

Comments
 (0)