Skip to content

Commit 8a614df

Browse files
committed
Lowering alignment attribute from memref.load/store to LLVM
1 parent 0019a9e commit 8a614df

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -841,8 +841,8 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
841841
adaptor.getMemref(),
842842
adaptor.getIndices(), kNoWrapFlags);
843843
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
844-
loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
845-
false, loadOp.getNontemporal());
844+
loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
845+
loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
846846
return success();
847847
}
848848
};
@@ -864,7 +864,8 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
864864
getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
865865
adaptor.getIndices(), kNoWrapFlags);
866866
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
867-
0, false, op.getNontemporal());
867+
op.getAlignment().value_or(0),
868+
false, op.getNontemporal());
868869
return success();
869870
}
870871
};

mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,15 @@ func.func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
148148

149149
// -----
150150

151+
// CHECK-LABEL: func @aligned_load(
152+
func.func @aligned_load(%static : memref<10x42xf32>, %i : index, %j : index) {
153+
// CHECK: llvm.load %{{.*}} {alignment = 16 : i64} : !llvm.ptr -> f32
154+
%0 = memref.load %static[%i, %j] { alignment = 16 } : memref<10x42xf32>
155+
return
156+
}
157+
158+
// -----
159+
151160
// CHECK-LABEL: func @zero_d_store
152161
func.func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
153162
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64)>
@@ -177,6 +186,16 @@ func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %va
177186

178187
// -----
179188

189+
// CHECK-LABEL: func @aligned_store
190+
func.func @aligned_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
191+
// CHECK: llvm.store %{{.*}}, %{{.*}} {alignment = 16 : i64} : f32, !llvm.ptr
192+
193+
memref.store %val, %static[%i, %j] { alignment = 16 } : memref<10x42xf32>
194+
return
195+
}
196+
197+
// -----
198+
180199
// CHECK-LABEL: func @static_memref_dim
181200
func.func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
182201
// CHECK: llvm.mlir.constant(42 : index) : i64

0 commit comments

Comments
 (0)