[MLIR][Conversion] Add convert-xevm-to-llvm pass. #147375
Conversation
Co-authored-by: Artem Kroviakov artem.kroviakov@intel.com
@llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

Changes

Co-authored-by: Artem Kroviakov artem.kroviakov@intel.com

Patch is 38.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147375.diff

9 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..8a5976e547169 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -80,6 +80,7 @@
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..929c3cf5a25ed 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1495,4 +1495,13 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
];
}
+//===----------------------------------------------------------------------===//
+// XeVMToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
+ let summary = "Convert XeVM to LLVM dialect";
+ let dependentDialects = ["xevm::XeVMDialect", ];
+}
+
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
new file mode 100644
index 0000000000000..b361af573afff
--- /dev/null
+++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
@@ -0,0 +1,29 @@
+//===-- XeVMToLLVM.h - Convert XeVM to LLVM dialect -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
+#define MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+} // namespace mlir
+
+namespace mlir {
+#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
+
+void registerConvertXeVMToLLVMInterface(DialectRegistry ®istry);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index f356b91b1b6c0..47e2f554abb69 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,6 +32,7 @@
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
@@ -90,6 +91,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
gpu::registerConvertGpuToLLVMInterface(registry);
NVVM::registerConvertGpuToNVVMInterface(registry);
vector::registerConvertVectorToLLVMInterface(registry);
+ registerConvertXeVMToLLVMInterface(registry);
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..24a48993ad80c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -73,3 +73,4 @@ add_subdirectory(VectorToLLVM)
add_subdirectory(VectorToSCF)
add_subdirectory(VectorToSPIRV)
add_subdirectory(VectorToXeGPU)
+add_subdirectory(XeVMToLLVM)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..4ac60d8d43472
--- /dev/null
+++ b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_mlir_conversion_library(MLIRXeVMToLLVM
+ XeVMToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeVMToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRGPUDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRPass
+ MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
new file mode 100644
index 0000000000000..89407825fd656
--- /dev/null
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -0,0 +1,669 @@
+//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "xevm-to-llvm"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace xevm;
+
+namespace {
+
+struct LLVMFuncAttributeOptions {
+ bool isConvergent = false;
+ bool isNoUnwind = false;
+ bool isWillReturn = false;
+ LLVM::MemoryEffectsAttr memEffectsAttr{};
+};
+static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
+ false, true, false, {}};
+static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
+ false, true, true, {}};
+static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
+ true, true, true, {}};
+
+std::string getTypeMangling(Type ty, bool isUnsigned = false) {
+ return TypeSwitch<Type, std::string>(ty)
+ .Case([isUnsigned](VectorType ty) -> std::string {
+ return "Dv" + std::to_string(ty.getNumElements()) + "_" +
+ getTypeMangling(ty.getElementType(), isUnsigned);
+ })
+ .Case([](Float16Type) -> std::string { return "Dh"; })
+ .Case([](Float32Type) -> std::string { return "f"; })
+ .Case([](Float64Type) -> std::string { return "d"; })
+ .Case([isUnsigned](IntegerType ty) -> std::string {
+ switch (ty.getWidth()) {
+ case 8:
+ return isUnsigned ? "h" : "c";
+ case 16:
+ return isUnsigned ? "t" : "s";
+ case 32:
+ return isUnsigned ? "j" : "i";
+ case 64:
+ return isUnsigned ? "m" : "l";
+ default:
+ llvm_unreachable("unhandled integer type");
+ }
+ });
+}
+
+std::string mangle(StringRef baseName, ArrayRef<Type> types,
+ ArrayRef<bool> isUnsigned = {}) {
+ assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
+ "Signedness info doesn't match");
+ std::string s;
+ llvm::raw_string_ostream os(s);
+ llvm::SmallDenseMap<Type, unsigned> substitutions;
+ os << "_Z" << baseName.size() << baseName;
+ for (auto [idx, type] : llvm::enumerate(types)) {
+ auto it = substitutions.find(type);
+ if (it != substitutions.end()) {
+ os << "S";
+ // First substitution is `S_`, second is `S0_`, and so on.
+ if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
+ os << firstIdx - 1;
+ os << "_";
+ } else {
+ if (!type.isIntOrFloat())
+ substitutions[type] = substitutions.size();
+ os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
+ }
+ }
+ return os.str();
+}
+
+template <bool isLoad, typename OpType>
+int32_t getL1CacheControl(OpType op) {
+ int32_t control = 0;
+ if constexpr (isLoad) {
+ switch (*op.getCacheControl()) {
+ case LoadCacheControl::L1UC_L2UC_L3UC:
+ case LoadCacheControl::L1UC_L2UC_L3C:
+ case LoadCacheControl::L1UC_L2C_L3UC:
+ case LoadCacheControl::L1UC_L2C_L3C:
+ control = 1;
+ break;
+ case LoadCacheControl::L1C_L2UC_L3UC:
+ case LoadCacheControl::L1C_L2UC_L3C:
+ case LoadCacheControl::L1C_L2C_L3UC:
+ case LoadCacheControl::L1C_L2C_L3C:
+ control = 2;
+ break;
+ case LoadCacheControl::L1S_L2UC_L3UC:
+ case LoadCacheControl::L1S_L2UC_L3C:
+ case LoadCacheControl::L1S_L2C_L3UC:
+ case LoadCacheControl::L1S_L2C_L3C:
+ control = 3;
+ break;
+ case LoadCacheControl::INVALIDATE_READ:
+ control = 4;
+ break;
+ default:
+ break;
+ }
+ } else {
+ switch (*op.getCacheControl()) {
+ case StoreCacheControl::L1UC_L2UC_L3UC:
+ case StoreCacheControl::L1UC_L2UC_L3WB:
+ case StoreCacheControl::L1UC_L2WB_L3UC:
+ case StoreCacheControl::L1UC_L2WB_L3WB:
+ control = 1;
+ break;
+ case StoreCacheControl::L1WT_L2UC_L3UC:
+ case StoreCacheControl::L1WT_L2UC_L3WB:
+ case StoreCacheControl::L1WT_L2WB_L3UC:
+ case StoreCacheControl::L1WT_L2WB_L3WB:
+ control = 2;
+ break;
+ case StoreCacheControl::L1S_L2UC_L3UC:
+ case StoreCacheControl::L1S_L2UC_L3WB:
+ case StoreCacheControl::L1S_L2WB_L3UC:
+ case StoreCacheControl::L1S_L2WB_L3WB:
+ control = 3;
+ break;
+ case StoreCacheControl::L1WB_L2UC_L3UC:
+ case StoreCacheControl::L1WB_L2WB_L3UC:
+ case StoreCacheControl::L1WB_L2UC_L3WB:
+ control = 4;
+ break;
+ default:
+ break;
+ }
+ }
+ return control;
+}
+
+template <bool isLoad, typename OpType>
+int32_t getL3CacheControl(OpType op) {
+ int32_t control = 0;
+ if constexpr (isLoad) {
+ switch (*op.getCacheControl()) {
+ case LoadCacheControl::L1UC_L2UC_L3UC:
+ control = 1;
+ break;
+ case LoadCacheControl::L1UC_L2UC_L3C:
+ control = 2;
+ break;
+ case LoadCacheControl::L1UC_L2C_L3UC:
+ control = 1;
+ break;
+ case LoadCacheControl::L1UC_L2C_L3C:
+ control = 2;
+ break;
+ case LoadCacheControl::L1C_L2UC_L3UC:
+ control = 1;
+ break;
+ case LoadCacheControl::L1C_L2UC_L3C:
+ control = 2;
+ break;
+ case LoadCacheControl::L1C_L2C_L3UC:
+ control = 1;
+ break;
+ case LoadCacheControl::L1C_L2C_L3C:
+ control = 2;
+ break;
+ case LoadCacheControl::L1S_L2UC_L3UC:
+ control = 1;
+ break;
+ case LoadCacheControl::L1S_L2UC_L3C:
+ control = 2;
+ break;
+ case LoadCacheControl::L1S_L2C_L3UC:
+ control = 1;
+ break;
+ case LoadCacheControl::L1S_L2C_L3C:
+ control = 2;
+ break;
+ case LoadCacheControl::INVALIDATE_READ:
+ control = 4;
+ break;
+ default:
+ break;
+ }
+ } else {
+ switch (*op.getCacheControl()) {
+ case StoreCacheControl::L1UC_L2UC_L3UC:
+ control = 1;
+ break;
+ case StoreCacheControl::L1UC_L2UC_L3WB:
+ control = 2;
+ break;
+ case StoreCacheControl::L1UC_L2WB_L3UC:
+ control = 1;
+ break;
+ case StoreCacheControl::L1UC_L2WB_L3WB:
+ control = 2;
+ break;
+ case StoreCacheControl::L1WT_L2UC_L3UC:
+ control = 1;
+ break;
+ case StoreCacheControl::L1WT_L2UC_L3WB:
+ control = 2;
+ break;
+ case StoreCacheControl::L1WT_L2WB_L3UC:
+ control = 1;
+ break;
+ case StoreCacheControl::L1WT_L2WB_L3WB:
+ control = 2;
+ break;
+ case StoreCacheControl::L1S_L2UC_L3UC:
+ control = 1;
+ break;
+ case StoreCacheControl::L1S_L2UC_L3WB:
+ control = 2;
+ break;
+ case StoreCacheControl::L1S_L2WB_L3UC:
+ control = 1;
+ break;
+ case StoreCacheControl::L1S_L2WB_L3WB:
+ control = 2;
+ break;
+ case StoreCacheControl::L1WB_L2UC_L3UC:
+ control = 1;
+ break;
+ case StoreCacheControl::L1WB_L2WB_L3UC:
+ control = 1;
+ break;
+ case StoreCacheControl::L1WB_L2UC_L3WB:
+ control = 2;
+ break;
+ default:
+ break;
+ }
+ }
+ return control;
+}
+
+template <bool isLoad, typename OpType>
+static std::optional<ArrayAttr>
+getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
+ if constexpr (isLoad) {
+ if (!op.getCacheControl())
+ return {};
+ } else {
+ if (!op.getCacheControl())
+ return {};
+ }
+ constexpr int32_t decorationCacheControlArity{4};
+ constexpr int32_t loadCacheControlKey{6442};
+ constexpr int32_t storeCacheControlKey{6443};
+ const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
+ SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
+ controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
+ SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
+ controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
+ auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
+ auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
+
+ SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
+ return rewriter.getArrayAttr(combinedAttrs);
+}
+
+static LLVM::CallOp createDeviceFunctionCall(
+ ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
+ ArrayRef<Type> argTypes, ArrayRef<Value> args,
+ mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
+ LLVMFuncAttributeOptions funcAttributeOptions) {
+ auto moduleOp = rewriter.getBlock()
+ ->getParentOp()
+ ->getParentWithTrait<OpTrait::SymbolTable>();
+ assert(moduleOp && "Expecting module");
+ MLIRContext *ctx = rewriter.getContext();
+ Location loc = UnknownLoc::get(ctx);
+
+ auto funcOpRes =
+ LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
+ assert(!failed(funcOpRes));
+ LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+ funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+ funcOp.setConvergent(funcAttributeOptions.isConvergent);
+ funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
+ funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
+
+ if (funcAttributeOptions.memEffectsAttr)
+ funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
+
+ for (auto [idx, attrName] : paramAttrs)
+ funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
+
+ auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
+ callOp->setAttrs(funcOp->getAttrs());
+
+ return callOp;
+}
+
+class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!op.getC()) {
+ return rewriter.notifyMatchFailure(op, "OCL requires C operand");
+ }
+ constexpr uint32_t bitWidthPackedA{16};
+ constexpr uint32_t bitWidthPackedB{32};
+ auto loc = op.getLoc();
+
+ auto castIfNeeded = [&](Value val, Type packedType) -> Value {
+ VectorType origTy = cast<VectorType>(val.getType());
+ const uint32_t vecBitSize =
+ origTy.getNumElements() *
+ origTy.getElementType().getIntOrFloatBitWidth();
+ VectorType newTy = VectorType::get(
+ vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
+ if (origTy != newTy)
+ val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
+ return val;
+ };
+
+ Value a = op.getA();
+ Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
+ ? cast<Type>(rewriter.getF32Type())
+ : rewriter.getIntegerType(bitWidthPackedA);
+ a = castIfNeeded(a, packedAType);
+
+ Value b = op.getB();
+ Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
+ ? cast<Type>(rewriter.getF32Type())
+ : rewriter.getIntegerType(bitWidthPackedB);
+ b = castIfNeeded(b, packedBType);
+
+ Value c = op.getC();
+ VectorType cOrigTy = cast<VectorType>(c.getType());
+ assert(cOrigTy == op->getResultTypes()[0] &&
+ "Accumulator and result type mismatch");
+ // OCL builtins encode bfloat16 as int16
+ VectorType cTy =
+ cOrigTy.getElementType().isBF16()
+ ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
+ : cOrigTy;
+ if (cOrigTy != cTy)
+ c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
+
+ constexpr int32_t systolicDepth{8};
+ std::string fnName =
+ llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
+ stringifyElemType(op.getTypes().getA()).str(),
+ stringifyElemType(op.getTypes().getB()).str(),
+ systolicDepth *
+ getNumOperandsPerDword(op.getTypes().getA()))
+ .str();
+ SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
+ fnName = mangle(fnName, argTypes);
+ SmallVector<Value> args{a, b, c};
+
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/LLVM::ModRefInfo::NoModRef,
+ /*argMem=*/LLVM::ModRefInfo::NoModRef,
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+ auto funcAttrs = convergentNoUnwindWillReturnAttrs;
+ funcAttrs.memEffectsAttr = memAttr;
+ Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes,
+ args, {}, funcAttrs)
+ ->getResult(0);
+
+ if (cOrigTy != cTy)
+ result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+private:
+ static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
+ switch (pTy) {
+ case xevm::ElemType::TF32:
+ return 1;
+ case xevm::ElemType::BF16:
+ case xevm::ElemType::F16:
+ return 2;
+ case xevm::ElemType::U8:
+ case xevm::ElemType::S8:
+ return 4;
+ default:
+ llvm_unreachable("unsupported xevm::ElemType");
+ }
+ }
+};
+
+class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
+ Value one = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1));
+ SmallVector<Value> args{op.getPtr(), one};
+ SmallVector<Type> argTypes;
+ for (auto arg : args)
+ argTypes.push_back(arg.getType());
+ auto funcAttr = noUnwindAttrs;
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/LLVM::ModRefInfo::NoModRef,
+ /*argMem=*/LLVM::ModRefInfo::Ref,
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+ funcAttr.memEffectsAttr = memAttr;
+
+ LLVM::CallOp call = createDeviceFunctionCall(
+ rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
+ argTypes, args, {}, funcAttr);
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata<true>(rewriter, op))
+ call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ const std::string fnName{"atomic_work_item_fence"};
+ int memScope, addrSpace;
+ switch (op.getAddrspace()) {
+ case xevm::AddrSpace::SHARED:
+ addrSpace = 1; // CLK_LOCAL_MEM_FENCE
+ break;
+ case xevm::AddrSpace::GLOBAL:
+ addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
+ break;
+ ...
[truncated]
also cc @akroviakov
template <bool isLoad, typename OpType>
static std::optional<ArrayAttr>
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
  if constexpr (isLoad) {
nit: both branches seem to be the same. Could one reduce this check to
if (!op.getCacheControl())
return {};
?
Thanks. Didn't realize they were the same.
Optimized as suggested.
Value memScopeConst = rewriter.create<LLVM::ConstantOp>(
    loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(memScope));
Value addrSpaceConst = rewriter.create<LLVM::ConstantOp>(
    loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(addrSpace));
nit: IntegerAttr should be avoidable for LLVM::ConstantOp
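A minimal sketch of what this nit likely suggests for the quoted lines above, assuming LLVM::ConstantOp has a convenience builder that takes the integer value directly (that overload is an assumption here, not something this PR shows):

```cpp
// Hedged sketch: build the i32 constants without wrapping the value in an
// IntegerAttr, assuming an LLVM::ConstantOp builder of the form (Type, int64_t).
Value memScopeConst = rewriter.create<LLVM::ConstantOp>(
    loc, rewriter.getI32Type(), static_cast<int64_t>(memScope));
Value addrSpaceConst = rewriter.create<LLVM::ConstantOp>(
    loc, rewriter.getI32Type(), static_cast<int64_t>(addrSpace));
```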
Done.
Why is this not in the translation?

This is neither lowering to built-in LLVM intrinsics (they are proper intrinsics but not standalone LLVM ops) nor using any automatically generated conversions. Conversion to LLVM seems like a suitable abstraction. Similarly to how …
Looks good from the perspective of infrastructure and overall flow.
Minor comments about the lowering logic itself.
I'll leave the in-depth review of that to @Jianhui-Li.
private:
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
Why no f32?
It seems valid looking at the op definition.
`f32` is valid in the op definition for `C` in the formula for MMA below:

    D = A*B + C

`A` and `B` do not support `f32`. `getNumOperandsPerDword` is a function for the types supported by `A` and `B`, where sub-dword (or 32-bit) types need to be packed/bundled into a dword. As the name implies, it returns how many operands of `pTy` are packed per dword.
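As a worked example of how that packing factor shows up in the generated builtin names, here is a small standalone sketch (plain C++, using only the constants visible in this patch: systolicDepth = 8 and the per-dword factors from getNumOperandsPerDword):

```cpp
#include <iostream>

int main() {
  constexpr int systolicDepth = 8;   // fixed constant in MMAToOCLPattern
  constexpr int perDwordTF32 = 1;    // tf32 fills a 32-bit dword
  constexpr int perDwordF16 = 2;     // two 16-bit values per dword
  constexpr int perDwordI8 = 4;      // four 8-bit values per dword
  // k in "intel_sub_group_<a>_<b>_matrix_mad_k<k>" = systolicDepth * packing(A)
  std::cout << "tf32     -> k" << systolicDepth * perDwordTF32 << "\n"  // k8
            << "f16/bf16 -> k" << systolicDepth * perDwordF16 << "\n"   // k16
            << "i8       -> k" << systolicDepth * perDwordI8 << "\n";   // k32
}
```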
> `A` and `B` do not support `f32`

In this case, the op definition has to be further constrained. At the moment, `xevm.mma` is happy to take `f32` for all its operands.
My bad. The MMA op can have a mismatch between
- the operand element type and
- the types declared in the attribute.

For example,

    %d = xevm.mma %a, %b, %c { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> }
      : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>

`i16` and `i32` are packed types. `f32` can appear in `A` or `B` if the true type or declared type in the attribute is `tf32`; the packed type for `tf32` is `f32`. In short, `f32` is a valid element type for operand `A` or `B`.
My larger point is that, for example, this op is valid today:

    func.func @mma_f32(%c: vector<8xf32>, %a: vector<8xf32>,
                       %b: vector<8xf32>) -> vector<8xf32> {
      %d = xevm.mma %a, %b, %c
             { shape=<m=8, n=16, k=16>,
               types=<d=f32, a=f32, b=f32, c=f32> }
             : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
      return %d : vector<8xf32>
    }

Ideally, unsupported combinations would be caught by the op verifier, as it's nice to have that as an invariant.
If there's a potential use case for such a combination (some rewrites, canonicalization, etc. directly on `xevm` ops), then please at least add extra checks to this pass so it doesn't blow up on an assertion.
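A possible shape for such a check, sketched as a fragment against the pattern's existing context (`op`, `rewriter`); the lambda and its placement at the top of MMAToOCLPattern::matchAndRewrite are assumptions, not code from this patch:

```cpp
// Hypothetical guard: bail out with a match failure instead of reaching
// llvm_unreachable in getNumOperandsPerDword for unsupported A/B types.
auto isSupportedABType = [](xevm::ElemType t) {
  switch (t) {
  case xevm::ElemType::TF32:
  case xevm::ElemType::BF16:
  case xevm::ElemType::F16:
  case xevm::ElemType::U8:
  case xevm::ElemType::S8:
    return true;
  default:
    return false;
  }
};
if (!isSupportedABType(op.getTypes().getA()) ||
    !isSupportedABType(op.getTypes().getB()))
  return rewriter.notifyMatchFailure(op, "unsupported A/B element type");
```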
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind,
// CHECK-SAME: sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64
It seems from the code that the cache hint is attached as an attribute to the call. Is this being checked?

    if (std::optional<ArrayAttr> optCacheControls =
            getCacheControlMetadata<true>(rewriter, op))
      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
Yes, it is.
But not as part of this PR. The metadata is consumed during translation.
That will be done in a follow-up PR, but I can provide a link to the code in the POC branch that consumes the metadata:
https://github.com/silee2/llvm-project/blob/xevmWorkspace/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp#L50-L72
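For reference, a standalone sketch of the shape of the metadata handed off to translation, using the values from getCacheControlMetadata in this patch (key 6442 for loads / 6443 for stores, cache level 0 = L1 and 1 = L3, then the control value and the operand index); reading the keys as SPIR-V cache-control decoration ids is my assumption:

```cpp
#include <array>
#include <cstdio>

int main() {
  // e.g. a load with L1 and L3 both "cached" (control value 2 in the patch),
  // decorating call operand 0:
  constexpr int loadKey = 6442; // 6443 for stores
  std::array<int, 4> l1{loadKey, /*level=*/0, /*control=*/2, /*operand=*/0};
  std::array<int, 4> l3{loadKey, /*level=*/1, /*control=*/2, /*operand=*/0};
  std::printf("[[%d, %d, %d, %d], [%d, %d, %d, %d]]\n",
              l1[0], l1[1], l1[2], l1[3], l3[0], l3[1], l3[2], l3[3]);
}
```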
// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
  <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
    pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
Can we add more tests to cover a few combinations of these parameters, say the following:
- load 32x16xi16, with v_blocks=2
- load 16x16xi32, transpose=true
- load 16x16xi16, pack_register=true
Added v_blocks, transpose, and pack_register test cases.
Great. Can you also change the shape size as described above?
auto funcAttrs = convergentNoUnwindWillReturnAttrs;
funcAttrs.memEffectsAttr = memAttr;
Value result =
    createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, args, {},
Please consider renaming cTy to dTy since this parameter specifies the return type.
Done.
Btw, there are a few warnings in CI. Please address them before merging.
> Ideally, unsupported combinations would be caught by the op verifier, as it's nice to have that as an invariant. If there's a potential use case for such a combination (some rewrites, canonicalization, etc. directly on `xevm` ops), then please at least add extra checks to this pass so it doesn't blow up on an assertion.
…values [-Wcovered-switch-default]
Valid point. I'll keep that in mind.
LGTM
Although XeVM is an LLVM extension dialect, the SPIR-V backend relies on function calls instead of defining LLVM intrinsics to represent SPIR-V instructions.
The convert-xevm-to-llvm pass lowers xevm ops to function declarations and calls using the above naming convention.
In the future, most of this pass should be replaced with llvmBuilder and handled as part of the translation to LLVM instead.
Co-authored-by: Artem Kroviakov artem.kroviakov@intel.com
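To illustrate the naming convention the description refers to, here is a small standalone C++ sketch that mimics the patch's mangle()/getTypeMangling() helpers for the simple case with no substitutions; the prefetch name is hard-coded in PrefetchToOCLPattern, while the MMA name is derived here from the same logic and is only illustrative:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Simplified Itanium-style mangling as used by the lowering: "_Z" + length of
// the base name + the base name + one code per argument type. (The real helper
// also folds repeated non-scalar types into S_-style substitutions.)
static std::string mangle(const std::string &baseName,
                          const std::vector<std::string> &argTypeCodes) {
  std::string result = "_Z" + std::to_string(baseName.size()) + baseName;
  for (const std::string &code : argTypeCodes)
    result += code;
  return result;
}

int main() {
  // prefetch(const global char *, size_t) -> "_Z8prefetchPU3AS1Kcm",
  // matching the name hard-coded in the patch.
  std::cout << mangle("prefetch", {"PU3AS1Kc", "m"}) << "\n";
  // intel_sub_group_bf16_bf16_matrix_mad_k16(v8i16, v8i32, v8f32), using the
  // type codes from getTypeMangling ("s" = i16, "i" = i32, "f" = f32).
  std::cout << mangle("intel_sub_group_bf16_bf16_matrix_mad_k16",
                      {"Dv8_s", "Dv8_i", "Dv8_f"})
            << "\n";
}
```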