Skip to content

Commit 8f9098f

Browse files
[Mosaic GPU][NFC] Add utils to create and check the SMEM memory space attribute.
PiperOrigin-RevId: 781447752
1 parent 70cdf17 commit 8f9098f

File tree

11 files changed

+67
-77
lines changed

11 files changed

+67
-77
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -494,12 +494,13 @@ def scratch_view(
494494
runtime scratch buffer.
495495
"""
496496
smem_base = None
497-
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
498497
i8 = ir.IntegerType.get_signless(8)
499498
i32 = ir.IntegerType.get_signless(32)
500499
if self.lowering_semantics == mgpu.LoweringSemantics.Lane:
501500
smem_base = gpu_dialect.dynamic_shared_memory(
502-
ir.MemRefType.get((mgpu_utils.DYNAMIC,), i8, memory_space=smem)
501+
ir.MemRefType.get(
502+
(mgpu_utils.DYNAMIC,), i8, memory_space=mgpu_utils.smem()
503+
)
503504
)
504505
views = []
505506
off = initial_used_bytes = self.smem_used_bytes
@@ -508,7 +509,7 @@ def scratch_view(
508509
scratch_ty = ir.MemRefType.get(
509510
s.shape,
510511
mgpu_utils.dtype_to_ir_type(s.dtype),
511-
memory_space=smem,
512+
memory_space=mgpu_utils.smem(),
512513
)
513514
# The below code emission relies on the assumption that the first scratch
514515
# operand provided by Mosaic GPU always begins at the beginning of
@@ -1271,7 +1272,7 @@ def _handle_dtype_bitcast(
12711272
"Data type bitcast is only supported from i8 to other types."
12721273
)
12731274
ref_ty = ir.MemRefType(ref.type)
1274-
if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
1275+
if not mgpu_utils.is_smem_ref(ref_ty):
12751276
raise ValueError(f"Only workgroup memory is supported but got {ref}.")
12761277
if len(ref_ty.shape) != 1:
12771278
raise NotImplementedError(

jax/experimental/mosaic/gpu/core.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,6 @@ def _construct_smem_reftree(
350350
index = ir.IndexType.get()
351351
i32 = ir.IntegerType.get_signless(32)
352352
i64 = ir.IntegerType.get_signless(64)
353-
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
354353
flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
355354
smem_buffers, is_leaf=lambda x: isinstance(x, Union)
356355
)
@@ -364,7 +363,7 @@ def barrier_memref(num_barriers: int) -> ir.Value:
364363
ir.Type.parse("!mosaic_gpu.barrier")
365364
if lowering_semantics == LoweringSemantics.Warpgroup
366365
else i64,
367-
memory_space=smem,
366+
memory_space=utils.smem(),
368367
)
369368
barrier_memref = _slice_smem(
370369
barrier_ty,
@@ -411,7 +410,7 @@ def ref(member_thunks=member_thunks):
411410
)
412411
case TMEM(shape, dtype, layout=layout, collective=collective, packing=packing):
413412
addr_ref = _slice_smem(
414-
ir.MemRefType.get([], i32, memory_space=smem),
413+
ir.MemRefType.get([], i32, memory_space=utils.smem()),
415414
dynamic_smem,
416415
c(dynamic_smem_offset, index),
417416
lowering_semantics,
@@ -433,7 +432,7 @@ def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout):
433432
case _:
434433
mlir_dtype = utils.dtype_to_ir_type(ref_ty.dtype)
435434
tile_smem = _slice_smem(
436-
ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem),
435+
ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=utils.smem()),
437436
dynamic_smem,
438437
c(dynamic_smem_offset, index),
439438
lowering_semantics,
@@ -531,18 +530,17 @@ def _launch(
531530
token.type, [token], *grid_vals, *block_vals,
532531
dynamicSharedMemorySize=c(smem_bytes, i32), **cluster_kwargs)
533532
launch_op.body.blocks.append(*([index] * (12 + 2 * len(cluster_kwargs)))) # Append an empty block
534-
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
535533
with ir.InsertionPoint(launch_op.body.blocks[0]):
536534
dynamic_smem = gpu.dynamic_shared_memory(
537-
ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem)
535+
ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=utils.smem())
538536
)
539537

540538
if profiler_spec:
541539
prof_smem = _slice_smem(
542540
ir.MemRefType.get(
543541
(profiler_spec.smem_i32_elements(block=block),),
544542
i32,
545-
memory_space=smem,
543+
memory_space=utils.smem(),
546544
),
547545
dynamic_smem,
548546
c(profiler_start, index),

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -514,9 +514,8 @@ def _vector_reduction_op_lowering_rule(
514514
a = _fragmented_array_from_ir(op.vector, layout, is_signed)
515515
match str(op.kind):
516516
case "#vector.kind<add>":
517-
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
518517
scratch = _slice_smem(
519-
ir.MemRefType.get([4], element_type, memory_space=smem),
518+
ir.MemRefType.get([4], element_type, memory_space=utils.smem()),
520519
arith.constant(None, op.attributes["offset"]),
521520
)
522521
result = a.reduce("add", range(len(a.shape)), scratch)
@@ -675,7 +674,7 @@ def _transformed_smem_ref_type(
675674
if not transforms and not transposed:
676675
return ref_ty
677676

678-
if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
677+
if not utils.is_smem_ref(ref_ty):
679678
raise ValueError(f"Only workgroup memory is supported but got {ref_ty}.")
680679

681680
shape = ref_ty.shape
@@ -1130,17 +1129,16 @@ def _mgpu_slice_smem_op_lowering_rule(
11301129

11311130
def _slice_smem(result: ir.Type, offset: ir.Value):
11321131
i8 = ir.IntegerType.get_signless(8)
1133-
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
11341132
smem_base = gpu.dynamic_shared_memory(
1135-
ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem)
1133+
ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=utils.smem())
11361134
)
11371135
offset = arith.index_cast(ir.IndexType.get(), offset)
11381136
lowered_result_type = result
11391137
if ir.MemRefType.isinstance(result):
11401138
memref_ty = ir.MemRefType(result)
11411139
if memref_ty.element_type == ir.Type.parse("!mosaic_gpu.barrier"):
11421140
lowered_result_type = ir.MemRefType.get(
1143-
memref_ty.shape, _lowered_barrier_type(), memory_space=smem
1141+
memref_ty.shape, _lowered_barrier_type(), memory_space=utils.smem()
11441142
)
11451143
view = memref.view(lowered_result_type, smem_base, offset, [])
11461144
if result == lowered_result_type:

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2665,7 +2665,7 @@ def transfer_tiled2(
26652665

26662666
if ref_ty.memory_space is None:
26672667
llvm_memory_space = None
2668-
elif ref_ty.memory_space == ir.Attribute.parse("#gpu.address_space<workgroup>"):
2668+
elif utils.is_smem_ref(ref_ty):
26692669
llvm_memory_space = 3
26702670
else:
26712671
raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}")

jax/experimental/mosaic/gpu/inference_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
from jax._src.lib.mlir import ir
2424

25+
from . import utils
26+
2527
MlirOperation = Union[ir.Operation, ir.OpView]
2628

2729
def in_layouts(op: MlirOperation) -> Sequence[ir.Attribute]:
@@ -147,13 +149,11 @@ def should_have_transforms(op: ir.OpView) -> bool:
147149
def is_transformable_smem_memref(v: ir.Value) -> bool:
148150
"""Whether the value is a memref in SMEM on which transforms should be applied."""
149151
barrier_ty = ir.Type.parse("!mosaic_gpu.barrier")
150-
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
151152
return (
152153
ir.MemRefType.isinstance(v.type)
153154
# barriers have no business being transformed
154155
and v.type.element_type != barrier_ty # pylint: disable=attribute-error
155-
and v.type.memory_space is not None # pylint: disable=attribute-error
156-
and v.type.memory_space == smem # pylint: disable=attribute-error
156+
and utils.is_smem_ref(v)
157157
)
158158

159159

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,6 @@ def async_copy(
594594
index = ir.IndexType.get()
595595
i16 = ir.IntegerType.get_signless(16)
596596
i32 = ir.IntegerType.get_signless(32)
597-
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
598597
src_ref_ty = ir.MemRefType(src_ref.type)
599598
dst_ref_ty = ir.MemRefType(dst_ref.type)
600599
element_type = src_ref_ty.element_type
@@ -609,13 +608,13 @@ def async_copy(
609608
if not isinstance(gmem_transform, tuple):
610609
gmem_transform = (gmem_transform,)
611610

612-
if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem:
611+
if src_ref_ty.memory_space is None and utils.is_smem_ref(dst_ref_ty):
613612
gmem_ref, smem_ref = src_ref, dst_ref
614613
if barrier is None:
615614
raise ValueError("Barriers are required for GMEM -> SMEM copies")
616615
if arrive is None:
617616
arrive = True # Arrive by default
618-
elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None:
617+
elif utils.is_smem_ref(src_ref_ty) and dst_ref_ty.memory_space is None:
619618
gmem_ref, smem_ref = dst_ref, src_ref
620619
if barrier is not None:
621620
raise ValueError("Barriers are unsupported for SMEM -> GMEM copies")

jax/experimental/mosaic/gpu/tcgen05.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact:
613613
ref_ty = ir.MemRefType(tmem_addr.type)
614614
if ref_ty.element_type != ir.IntegerType.get_signless(32):
615615
raise ValueError(f"tmem_addr must be an i32 memref, got: {ref_ty}")
616-
if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
616+
if not utils.is_smem_ref(ref_ty):
617617
raise ValueError(f"tmem_addr must be in shared memory, got: {ref_ty}")
618618
if math.prod(ref_ty.shape) != 1:
619619
raise ValueError(f"tmem_addr must contain a single element, got: {ref_ty}")
@@ -870,9 +870,8 @@ def from_alloc(
870870
if not ir.MemRefType.isinstance(tmem_addr_ref.type):
871871
raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
872872
addr_ref_ty = ir.MemRefType(tmem_addr_ref.type)
873-
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
874-
if addr_ref_ty.memory_space != smem:
875-
raise ValueError(f"tmem_addr_ref must be in workgroup memory, got: {addr_ref_ty}")
873+
if not utils.is_smem_ref(addr_ref_ty):
874+
raise ValueError(f"tmem_addr_ref must be in shared memory, got: {addr_ref_ty}")
876875
if addr_ref_ty.element_type != i32:
877876
raise ValueError(f"tmem_addr_ref must be an i32 memref, got: {addr_ref_ty}")
878877
if math.prod(addr_ref_ty.shape) != 1:

jax/experimental/mosaic/gpu/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,3 +1548,20 @@ def is_known_divisible(value, divisor, max_depth=10) -> bool:
15481548
)
15491549

15501550
return False
1551+
1552+
1553+
def smem() -> ir.Attribute:
1554+
"""Returns the attribute for the SMEM memory space."""
1555+
return ir.Attribute.parse("#gpu.address_space<workgroup>")
1556+
1557+
1558+
def is_smem_ref(ref: ir.Value | ir.Type) -> bool:
1559+
"""Returns true if the input mem ref or memref type points to SMEM.
1560+
If the input is not at all of a memref type, raises a ValueError.
1561+
"""
1562+
if isinstance(ref, ir.Value):
1563+
ref = ref.type
1564+
if not ir.MemRefType.isinstance(ref):
1565+
raise ValueError(f"Expected a memref type but got {ref}")
1566+
ref = ir.MemRefType(ref)
1567+
return ref.memory_space is not None and ref.memory_space == smem()

tests/mosaic/gpu_dialect_test.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -967,8 +967,7 @@ def test_lowering_slice_smem_op(self):
967967
def body():
968968
nonlocal offset
969969
i32 = ir.IntegerType.get_signless(32)
970-
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
971-
memref_ty = ir.MemRefType.get((4, 32), i32, memory_space=smem)
970+
memref_ty = ir.MemRefType.get((4, 32), i32, memory_space=mgpu_utils.smem())
972971
offset = arith.constant(i32, shift)
973972
op = mgpu.dialect.SliceSMEMOp(memref_ty, offset)
974973
op.attributes["out_transforms"] = ir.ArrayAttr.get([ir.ArrayAttr.get([])])
@@ -1051,8 +1050,7 @@ def body(vec1, vec2, ref):
10511050
)
10521051

10531052
with ir.InsertionPoint(self.module.body):
1054-
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
1055-
ref_ty = ir.MemRefType.get((4, 32), ir.BF16Type.get(), memory_space=smem)
1053+
ref_ty = ir.MemRefType.get((4, 32), ir.BF16Type.get(), memory_space=mgpu_utils.smem())
10561054
func.FuncOp.from_py_func(vec_ty, vec_ty, ref_ty)(body)
10571055

10581056
if omit_in_layouts:
@@ -1071,7 +1069,7 @@ def test_memref_transforms_with_transpose(self):
10711069
ty_in = ir.MemRefType.get(
10721070
(64, 128),
10731071
ir.BF16Type.get(),
1074-
memory_space=ir.Attribute.parse("#gpu.address_space<workgroup>"),
1072+
memory_space=mgpu_utils.smem(),
10751073
)
10761074
ref = memref.alloc(ty_in, [], [])
10771075

tests/mosaic/gpu_test.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member
4040
from jax.experimental.mosaic.gpu import fragmented_array as fa
4141
from jax.experimental.mosaic.gpu import tcgen05
42+
from jax.experimental.mosaic.gpu import utils as mgpu_utils
4243
import jax.numpy as jnp
4344
import numpy as np
4445
try:
@@ -130,7 +131,6 @@ def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
130131
packing = 8 // bw
131132
if shape[-1] % packing:
132133
raise NotImplementedError
133-
workgroup_mem = ir.Attribute.parse("#gpu.address_space<workgroup>")
134134
shape = (*shape[:-1], shape[-1] // packing)
135135
contig_strides = get_contiguous_strides(shape)
136136
def bitcast(ref):
@@ -145,12 +145,7 @@ def bitcast(ref):
145145
ir.StridedLayoutAttr.get(0, new_strides),
146146
ref_ty.memory_space,
147147
)
148-
ptr_space = (
149-
3
150-
if ref_ty.memory_space is not None
151-
and ref_ty.memory_space == workgroup_mem
152-
else None
153-
)
148+
ptr_space = 3 if mgpu_utils.is_smem_ref(ref_ty) else None
154149
return ptr_as_memref(
155150
# NOTE: memref_ptr applies the offset in case there was any.
156151
memref_ptr(ref, memory_space=ptr_space),

0 commit comments

Comments
 (0)