|
20 | 20 | import torch.fx as fx
|
21 | 21 |
|
22 | 22 | from .indexing import (
|
23 |
| - BoundedSymbolicValue, |
| 23 | + backed_sym_index_type, |
| 24 | + BoundedRelation, |
| 25 | + IndexExpr, |
24 | 26 | Grid,
|
25 | 27 | KernelBuffer,
|
26 |
| - sym_0, |
| 28 | + SymIndex, |
27 | 29 | )
|
28 | 30 |
|
29 | 31 | from ..lang.types import (
|
@@ -98,10 +100,17 @@ class KernelTracer(SubgraphTracer):
|
98 | 100 | # Register our custom proxies.
|
99 | 101 | def proxy(self, node: fx.Node) -> fx.Proxy:
|
100 | 102 | t = node.type
|
101 |
| - if t is not None and issubclass(t, KernelBuffer): |
102 |
| - return KernelBufferProxy(node, self, t) |
| 103 | + if t is not None: |
| 104 | + if issubclass(t, KernelBuffer): |
| 105 | + return KernelBufferProxy(node, self, t) |
103 | 106 | return super().proxy(node)
|
104 | 107 |
|
| 108 | + def create_arg(self, a): |
| 109 | + # Let IndexExpr persist as arguments. |
| 110 | + if isinstance(a, IndexExpr): |
| 111 | + return a |
| 112 | + return super().create_arg(a) |
| 113 | + |
105 | 114 |
|
106 | 115 | class CapturedTrace:
|
107 | 116 | def __init__(self, region_graph: RegionGraph, root_graph: str):
|
@@ -163,23 +172,28 @@ def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]):
|
163 | 172 | super().__init__(eager=False)
|
164 | 173 | self.region_graph = region_graph
|
165 | 174 | self.grid_type = grid_type
|
| 175 | + self.current_thread_types = [ |
| 176 | + backed_sym_index_type(BoundedRelation(0, n, upper_inclusive=False)) |
| 177 | + for n in grid_type.symbolic_shape |
| 178 | + ] |
166 | 179 |
|
167 | 180 | ### ========================================================================
|
168 | 181 | ### Core Operations
|
169 | 182 | ### ========================================================================
|
170 | 183 |
|
171 | 184 | def handle_thread_program_id(self, op, axis: int) -> Index:
|
172 |
| - grid_shape = self.grid_type.symbolic_shape |
173 |
| - if axis < 0 or axis >= len(grid_shape): |
| 185 | + grid_types = self.current_thread_types |
| 186 | + if axis < 0 or axis >= len(grid_types): |
174 | 187 | raise IndexError(
|
175 |
| - f"Illegal index into grid of rank {len(grid_shape)}: {axis}" |
| 188 | + f"Illegal index into grid of rank {len(grid_types)}: {axis}" |
176 | 189 | )
|
| 190 | + |
177 | 191 | proxy = self.region_graph.create_proxy(
|
178 | 192 | "call_function",
|
179 | 193 | op,
|
180 | 194 | args=(axis,),
|
181 | 195 | kwargs={},
|
182 |
| - type_expr=BoundedSymbolicValue.bound(sym_0, grid_shape[axis]), |
| 196 | + type_expr=grid_types[axis], |
183 | 197 | )
|
184 | 198 | return proxy
|
185 | 199 |
|
|
0 commit comments