Skip to content

Commit 03af274

Browse files
[Mosaic GPU][NFC] Enable the gmem argument to async_copy to be an ir.BlockArgument
This is needed for upcoming changes, in which we will create a block in the lowering of `CustomPrimitiveOp`. Kernel arguments are accessed inside that block via its block arguments. However, the lowering of `CustomPrimitiveOp` will also call `async_copy` if the body contains it. `async_copy` creates a TMA descriptor, and to initialize this descriptor on the host we need to get back to the original `unrealized_conversion_cast`. PiperOrigin-RevId: 781581079
1 parent aa24d5e commit 03af274

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,29 @@ class _DefaultPredicate:
324324
pass
325325

326326

327+
def _find_kernel_argument_for_gmem_ref(
    gmem_ref: ir.Value,
) -> ir.Value:
  """Returns the kernel argument value for a given gmem_ref.

  The kernel argument is expected to be the result of an unrealized conversion
  cast. If `gmem_ref` is a block argument, this function iteratively walks up
  through the enclosing operations (in case of nested blocks), each time
  replacing the block argument with the operand found at the same index on the
  operation that owns the block.

  Args:
    gmem_ref: The value to resolve. Must have a memref type.

  Returns:
    The value (not the defining operation) reached once no more block
    arguments remain to be followed. Its owner is an unrealized conversion
    cast.

  Raises:
    ValueError: If `gmem_ref` does not have a memref type.
    NotImplementedError: If the resolved value is not produced by an
      unrealized conversion cast.
  """
  if not isinstance(gmem_ref.type, ir.MemRefType):
    raise ValueError(f"Expected {gmem_ref} to have a memref type.")

  while isinstance(gmem_ref, ir.BlockArgument):
    # For a block argument, `owner` is the block and the block's `owner` is
    # the enclosing operation. This assumes that block argument `i` is fed by
    # operand `i` of that operation.
    enclosing_op = gmem_ref.owner.owner
    gmem_ref = enclosing_op.operands[gmem_ref.arg_number]

  # TODO(apaszke): This is a very approximate check. Improve it!
  if not isinstance(gmem_ref.owner.opview, builtin.UnrealizedConversionCastOp):
    raise NotImplementedError(
        f"Expected {gmem_ref.owner} to be an unrealized conversion cast"
        " corresponding to a GMEM kernel argument."
    )
  return gmem_ref
349+
327350
@dataclasses.dataclass()
328351
class LaunchContext:
329352
module: ir.Module
@@ -406,6 +429,7 @@ def _get_tma_desc(
406429
"add","min","max","inc","dec","and","or","xor"
407430
] | None,
408431
):
432+
gmem_ref = _find_kernel_argument_for_gmem_ref(gmem_ref)
409433
# Using ir.Values in cache keys is a little sketchy, but I think it should
410434
# be fine. Having it in the key will keep it alive, and if comparison and
411435
# hashing is by identity then it should work out.
@@ -599,13 +623,10 @@ def async_copy(
599623
arrive = True # Commit this copy to the async group by default
600624
else:
601625
raise ValueError("Only SMEM <-> GMEM copies supported")
602-
# TODO(apaszke): This is a very approximate check. Improve it!
603-
expected_name = "builtin.unrealized_conversion_cast"
604-
if (
605-
gmem_ref.owner is None
606-
or gmem_ref.owner.opview.OPERATION_NAME != expected_name
607-
):
608-
raise ValueError("GMEM reference in async_copy must be a kernel argument")
626+
627+
# The function below is called only to verify the GMEM ref. The output
628+
# is meant to be ignored.
629+
_find_kernel_argument_for_gmem_ref(gmem_ref)
609630
gmem_ref_ty = ir.MemRefType(gmem_ref.type)
610631
gmem_strides, _ = gmem_ref_ty.get_strides_and_offset()
611632
if gmem_strides != utils.get_contiguous_strides(gmem_ref_ty.shape):

0 commit comments

Comments
 (0)