Skip to content

Commit b79f3b7

Browse files
apaszke authored and jax authors committed
[Mosaic:GPU] Update lowering to match upstream changes in the LLVM dialect
LLVM integer arithmetic ops now explicitly require the overflow flags.
PiperOrigin-RevId: 627020143
1 parent 83aff78 commit b79f3b7

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

jax/experimental/mosaic/gpu/wgmma.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ def wgmma_encode(x: int):
8282
return result
8383

8484

85+
def llvm_mul(x, y):
  """Build an `llvm.mul` of `x` and `y` with no overflow flags set.

  The overflow flags are passed explicitly (as `IntegerOverflowFlags.none`)
  so that call sites do not have to repeat the keyword argument.
  """
  no_flags = llvm.IntegerOverflowFlags.none
  return llvm.mul(x, y, overflow_flags=no_flags)
87+
88+
89+
def llvm_add(x, y):
  """Build an `llvm.add` of `x` and `y` with no overflow flags set.

  The overflow flags are passed explicitly (as `IntegerOverflowFlags.none`)
  so that call sites do not have to repeat the keyword argument.
  """
  no_flags = llvm.IntegerOverflowFlags.none
  return llvm.add(x, y, overflow_flags=no_flags)
91+
92+
8593
def get_memref_base(memref_arg, memory_space=None):
8694
i64 = ir.IntegerType.get_signless(64)
8795
memref_ty = ir.MemRefType(memref_arg.type)
@@ -99,9 +107,9 @@ def get_memref_base(memref_arg, memory_space=None):
99107
desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
100108
aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])
101109
offset_elems = llvm.extractvalue(i64, desc, [2])
102-
offset_bytes = llvm.mul(offset_elems, c(elem_bytewidth, i64))
110+
offset_bytes = llvm_mul(offset_elems, c(elem_bytewidth, i64))
103111
return llvm.inttoptr(
104-
ptr_ty, llvm.add(llvm.ptrtoint(i64, aligned_ptr), offset_bytes)
112+
ptr_ty, llvm_add(llvm.ptrtoint(i64, aligned_ptr), offset_bytes)
105113
)
106114

107115

@@ -246,14 +254,14 @@ def as_i32_reg(v):
246254
a_args = [as_i32_reg(v) for v in a_slice.registers.flat]
247255
else:
248256
if i > 0:
249-
a = llvm.add(
257+
a = llvm_add(
250258
a,
251259
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
252260
)
253261
a_args = [a]
254262
# Advance the B descriptor.
255263
if i > 0:
256-
b_descriptor = llvm.add(
264+
b_descriptor = llvm_add(
257265
b_descriptor,
258266
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
259267
)
@@ -388,11 +396,11 @@ def wgmma(
388396
if a_in_regs:
389397
a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tile : (ki + 1) * kn_tile]
390398
else:
391-
a_mk = llvm.add(
399+
a_mk = llvm_add(
392400
a_desc_base,
393401
c(wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
394402
)
395-
b_k = llvm.add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64))
403+
b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64))
396404
new_acc_regs[mi : mi + 1] = wgmma_m64k128B(
397405
new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params
398406
)

0 commit comments

Comments
 (0)