Skip to content

Commit 9c3085b

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Pallas:MGPU] Disallow direct reads from TMEM refs
Leaving this piece of code there was an oversight in the CL that introduced plgpu.async_load_tmem. PiperOrigin-RevId: 781919096
1 parent 7813749 commit 9c3085b

File tree

1 file changed

+0
-11
lines changed

1 file changed

+0
-11
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,17 +1458,6 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...
14581458
def _get_lowering_rule(
14591459
ctx: LoweringRuleContext, x_ref, *leaves, tree, optimized=True
14601460
):
1461-
if isinstance(x_ref, tcgen05.TMEMRef):
1462-
transforms = jax.tree.unflatten(tree, leaves)
1463-
x_tmem, transforms = _handle_transforms(
1464-
ctx, x_ref, transforms, handle_transposes=False, handle_reshapes=False,
1465-
)
1466-
if transforms:
1467-
raise NotImplementedError(
1468-
f"Unimplemented transforms for TMEM refs. {transforms=}"
1469-
)
1470-
return x_tmem.load(layout=ctx.out_layout_hint)
1471-
14721461
if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref):
14731462
raise TypeError(f"Can only load from references (got {x_ref}).")
14741463
dtype = ctx.avals_out[0].dtype

0 commit comments

Comments
 (0)