@@ -82,6 +82,14 @@ def wgmma_encode(x: int):
82
82
return result
83
83
84
84
85
+ def llvm_mul (x , y ):
86
+ return llvm .mul (x , y , overflow_flags = llvm .IntegerOverflowFlags .none )
87
+
88
+
89
+ def llvm_add (x , y ):
90
+ return llvm .add (x , y , overflow_flags = llvm .IntegerOverflowFlags .none )
91
+
92
+
85
93
def get_memref_base (memref_arg , memory_space = None ):
86
94
i64 = ir .IntegerType .get_signless (64 )
87
95
memref_ty = ir .MemRefType (memref_arg .type )
@@ -99,9 +107,9 @@ def get_memref_base(memref_arg, memory_space=None):
99
107
desc = builtin .UnrealizedConversionCastOp ([desc_ty ], [memref_arg ])
100
108
aligned_ptr = llvm .extractvalue (ptr_ty , desc , [1 ])
101
109
offset_elems = llvm .extractvalue (i64 , desc , [2 ])
102
- offset_bytes = llvm . mul (offset_elems , c (elem_bytewidth , i64 ))
110
+ offset_bytes = llvm_mul (offset_elems , c (elem_bytewidth , i64 ))
103
111
return llvm .inttoptr (
104
- ptr_ty , llvm . add (llvm .ptrtoint (i64 , aligned_ptr ), offset_bytes )
112
+ ptr_ty , llvm_add (llvm .ptrtoint (i64 , aligned_ptr ), offset_bytes )
105
113
)
106
114
107
115
@@ -246,14 +254,14 @@ def as_i32_reg(v):
246
254
a_args = [as_i32_reg (v ) for v in a_slice .registers .flat ]
247
255
else :
248
256
if i > 0 :
249
- a = llvm . add (
257
+ a = llvm_add (
250
258
a ,
251
259
llvm .ConstantOp (i64 , ir .IntegerAttr .get (i64 , a_k_stride >> 4 )),
252
260
)
253
261
a_args = [a ]
254
262
# Advance the B descriptor.
255
263
if i > 0 :
256
- b_descriptor = llvm . add (
264
+ b_descriptor = llvm_add (
257
265
b_descriptor ,
258
266
llvm .ConstantOp (i64 , ir .IntegerAttr .get (i64 , b_k_stride >> 4 )),
259
267
)
@@ -388,11 +396,11 @@ def wgmma(
388
396
if a_in_regs :
389
397
a_mk = a [mi * 64 : (mi + 1 ) * 64 , ki * kn_tile : (ki + 1 ) * kn_tile ]
390
398
else :
391
- a_mk = llvm . add (
399
+ a_mk = llvm_add (
392
400
a_desc_base ,
393
401
c (wgmma_encode (mi * a_m_byte_stride + ki * a_k_byte_stride ), i64 ),
394
402
)
395
- b_k = llvm . add (b_desc_base , c (wgmma_encode (ki * b_k_byte_stride ), i64 ))
403
+ b_k = llvm_add (b_desc_base , c (wgmma_encode (ki * b_k_byte_stride ), i64 ))
396
404
new_acc_regs [mi : mi + 1 ] = wgmma_m64k128B (
397
405
new_acc_regs [mi : mi + 1 ], a_mk , b_k , ** wgmma_params
398
406
)
0 commit comments