Skip to content

[MLIR][Conversion] Add convert-xevm-to-llvm pass. #147375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

silee2
Copy link
Contributor

@silee2 silee2 commented Jul 7, 2025

Although XeVM is an LLVM extension dialect,
SPIR-V backend relies on function calls instead of defining LLVM intrinsics to represent SPIR-V instructions.
convert-xevm-to-llvm pass lowers xevm ops to function declarations and calls using the above naming convention.
In the future, most part of the pass should be replaced with llvmBuilder and handled as part of translation to LLVM instead.

Co-authored-by: Artem Kroviakov artem.kroviakov@intel.com

Co-authored-by: Artem Kroviakov artem.kroviakov@intel.com
@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

Changes

Co-authored-by: Artem Kroviakov artem.kroviakov@intel.com


Patch is 38.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147375.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+9)
  • (added) mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h (+29)
  • (modified) mlir/include/mlir/InitAllExtensions.h (+2)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt (+21)
  • (added) mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp (+669)
  • (added) mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir (+83)
  • (modified) mlir/test/lib/Dialect/GPU/CMakeLists.txt (+1)
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..8a5976e547169 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -80,6 +80,7 @@
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
 
 namespace mlir {
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..929c3cf5a25ed 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1495,4 +1495,13 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// XeVMToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
+  let summary = "Convert XeVM to LLVM dialect";
+  let dependentDialects = ["xevm::XeVMDialect", ];
+}
+
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
new file mode 100644
index 0000000000000..b361af573afff
--- /dev/null
+++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
@@ -0,0 +1,29 @@
+//===-- XeVMToLLVM.h - Convert XeVM to LLVM dialect -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
+#define MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+} // namespace mlir
+
+namespace mlir {
+#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
+
+void registerConvertXeVMToLLVMInterface(DialectRegistry &registry);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index f356b91b1b6c0..47e2f554abb69 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,6 +32,7 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
@@ -90,6 +91,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
+  registerConvertXeVMToLLVMInterface(registry);
 
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..24a48993ad80c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -73,3 +73,4 @@ add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)
 add_subdirectory(VectorToSPIRV)
 add_subdirectory(VectorToXeGPU)
+add_subdirectory(XeVMToLLVM)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..4ac60d8d43472
--- /dev/null
+++ b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_mlir_conversion_library(MLIRXeVMToLLVM
+  XeVMToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeVMToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRGPUDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRXeVMDialect
+  MLIRPass
+  MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
new file mode 100644
index 0000000000000..89407825fd656
--- /dev/null
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -0,0 +1,669 @@
+//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "xevm-to-llvm"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace xevm;
+
+namespace {
+
+struct LLVMFuncAttributeOptions {
+  bool isConvergent = false;
+  bool isNoUnwind = false;
+  bool isWillReturn = false;
+  LLVM::MemoryEffectsAttr memEffectsAttr{};
+};
+static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
+    false, true, false, {}};
+static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
+    false, true, true, {}};
+static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
+    true, true, true, {}};
+
+std::string getTypeMangling(Type ty, bool isUnsigned = false) {
+  return TypeSwitch<Type, std::string>(ty)
+      .Case([isUnsigned](VectorType ty) -> std::string {
+        return "Dv" + std::to_string(ty.getNumElements()) + "_" +
+               getTypeMangling(ty.getElementType(), isUnsigned);
+      })
+      .Case([](Float16Type) -> std::string { return "Dh"; })
+      .Case([](Float32Type) -> std::string { return "f"; })
+      .Case([](Float64Type) -> std::string { return "d"; })
+      .Case([isUnsigned](IntegerType ty) -> std::string {
+        switch (ty.getWidth()) {
+        case 8:
+          return isUnsigned ? "h" : "c";
+        case 16:
+          return isUnsigned ? "t" : "s";
+        case 32:
+          return isUnsigned ? "j" : "i";
+        case 64:
+          return isUnsigned ? "m" : "l";
+        default:
+          llvm_unreachable("unhandled integer type");
+        }
+      });
+}
+
+std::string mangle(StringRef baseName, ArrayRef<Type> types,
+                   ArrayRef<bool> isUnsigned = {}) {
+  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
+         "Signedness info doesn't match");
+  std::string s;
+  llvm::raw_string_ostream os(s);
+  llvm::SmallDenseMap<Type, unsigned> substitutions;
+  os << "_Z" << baseName.size() << baseName;
+  for (auto [idx, type] : llvm::enumerate(types)) {
+    auto it = substitutions.find(type);
+    if (it != substitutions.end()) {
+      os << "S";
+      // First substitution is `S_`, second is `S0_`, and so on.
+      if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
+        os << firstIdx - 1;
+      os << "_";
+    } else {
+      if (!type.isIntOrFloat())
+        substitutions[type] = substitutions.size();
+      os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
+    }
+  }
+  return os.str();
+}
+
+template <bool isLoad, typename OpType>
+int32_t getL1CacheControl(OpType op) {
+  int32_t control = 0;
+  if constexpr (isLoad) {
+    switch (*op.getCacheControl()) {
+    case LoadCacheControl::L1UC_L2UC_L3UC:
+    case LoadCacheControl::L1UC_L2UC_L3C:
+    case LoadCacheControl::L1UC_L2C_L3UC:
+    case LoadCacheControl::L1UC_L2C_L3C:
+      control = 1;
+      break;
+    case LoadCacheControl::L1C_L2UC_L3UC:
+    case LoadCacheControl::L1C_L2UC_L3C:
+    case LoadCacheControl::L1C_L2C_L3UC:
+    case LoadCacheControl::L1C_L2C_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1S_L2UC_L3UC:
+    case LoadCacheControl::L1S_L2UC_L3C:
+    case LoadCacheControl::L1S_L2C_L3UC:
+    case LoadCacheControl::L1S_L2C_L3C:
+      control = 3;
+      break;
+    case LoadCacheControl::INVALIDATE_READ:
+      control = 4;
+      break;
+    default:
+      break;
+    }
+  } else {
+    switch (*op.getCacheControl()) {
+    case StoreCacheControl::L1UC_L2UC_L3UC:
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+    case StoreCacheControl::L1UC_L2WB_L3UC:
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WT_L2UC_L3UC:
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+    case StoreCacheControl::L1WT_L2WB_L3UC:
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1S_L2UC_L3UC:
+    case StoreCacheControl::L1S_L2UC_L3WB:
+    case StoreCacheControl::L1S_L2WB_L3UC:
+    case StoreCacheControl::L1S_L2WB_L3WB:
+      control = 3;
+      break;
+    case StoreCacheControl::L1WB_L2UC_L3UC:
+    case StoreCacheControl::L1WB_L2WB_L3UC:
+    case StoreCacheControl::L1WB_L2UC_L3WB:
+      control = 4;
+      break;
+    default:
+      break;
+    }
+  }
+  return control;
+}
+
+template <bool isLoad, typename OpType>
+int32_t getL3CacheControl(OpType op) {
+  int32_t control = 0;
+  if constexpr (isLoad) {
+    switch (*op.getCacheControl()) {
+    case LoadCacheControl::L1UC_L2UC_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1UC_L2UC_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1UC_L2C_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1UC_L2C_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1C_L2UC_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1C_L2UC_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1C_L2C_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1C_L2C_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1S_L2UC_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1S_L2UC_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1S_L2C_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1S_L2C_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::INVALIDATE_READ:
+      control = 4;
+      break;
+    default:
+      break;
+    }
+  } else {
+    switch (*op.getCacheControl()) {
+    case StoreCacheControl::L1UC_L2UC_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1UC_L2WB_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1WT_L2UC_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1WT_L2WB_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1S_L2UC_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1S_L2UC_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1S_L2WB_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1S_L2WB_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1WB_L2UC_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WB_L2WB_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WB_L2UC_L3WB:
+      control = 2;
+      break;
+    default:
+      break;
+    }
+  }
+  return control;
+}
+
+template <bool isLoad, typename OpType>
+static std::optional<ArrayAttr>
+getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
+  if constexpr (isLoad) {
+    if (!op.getCacheControl())
+      return {};
+  } else {
+    if (!op.getCacheControl())
+      return {};
+  }
+  constexpr int32_t decorationCacheControlArity{4};
+  constexpr int32_t loadCacheControlKey{6442};
+  constexpr int32_t storeCacheControlKey{6443};
+  const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
+  SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
+      controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
+  SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
+      controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
+  auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
+  auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
+
+  SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
+  return rewriter.getArrayAttr(combinedAttrs);
+}
+
+static LLVM::CallOp createDeviceFunctionCall(
+    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
+    ArrayRef<Type> argTypes, ArrayRef<Value> args,
+    mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
+    LLVMFuncAttributeOptions funcAttributeOptions) {
+  auto moduleOp = rewriter.getBlock()
+                      ->getParentOp()
+                      ->getParentWithTrait<OpTrait::SymbolTable>();
+  assert(moduleOp && "Expecting module");
+  MLIRContext *ctx = rewriter.getContext();
+  Location loc = UnknownLoc::get(ctx);
+
+  auto funcOpRes =
+      LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
+  assert(!failed(funcOpRes));
+  LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+  funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+  funcOp.setConvergent(funcAttributeOptions.isConvergent);
+  funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
+  funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
+
+  if (funcAttributeOptions.memEffectsAttr)
+    funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
+
+  for (auto [idx, attrName] : paramAttrs)
+    funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
+
+  auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
+  callOp->setAttrs(funcOp->getAttrs());
+
+  return callOp;
+}
+
+class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getC()) {
+      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
+    }
+    constexpr uint32_t bitWidthPackedA{16};
+    constexpr uint32_t bitWidthPackedB{32};
+    auto loc = op.getLoc();
+
+    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
+      VectorType origTy = cast<VectorType>(val.getType());
+      const uint32_t vecBitSize =
+          origTy.getNumElements() *
+          origTy.getElementType().getIntOrFloatBitWidth();
+      VectorType newTy = VectorType::get(
+          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
+      if (origTy != newTy)
+        val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
+      return val;
+    };
+
+    Value a = op.getA();
+    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedA);
+    a = castIfNeeded(a, packedAType);
+
+    Value b = op.getB();
+    Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedB);
+    b = castIfNeeded(b, packedBType);
+
+    Value c = op.getC();
+    VectorType cOrigTy = cast<VectorType>(c.getType());
+    assert(cOrigTy == op->getResultTypes()[0] &&
+           "Accumulator and result type mismatch");
+    // OCL builtins encode bfloat16 as int16
+    VectorType cTy =
+        cOrigTy.getElementType().isBF16()
+            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
+            : cOrigTy;
+    if (cOrigTy != cTy)
+      c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
+
+    constexpr int32_t systolicDepth{8};
+    std::string fnName =
+        llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
+                      stringifyElemType(op.getTypes().getA()).str(),
+                      stringifyElemType(op.getTypes().getB()).str(),
+                      systolicDepth *
+                          getNumOperandsPerDword(op.getTypes().getA()))
+            .str();
+    SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
+    fnName = mangle(fnName, argTypes);
+    SmallVector<Value> args{a, b, c};
+
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/LLVM::ModRefInfo::NoModRef,
+        /*argMem=*/LLVM::ModRefInfo::NoModRef,
+        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
+    funcAttrs.memEffectsAttr = memAttr;
+    Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes,
+                                            args, {}, funcAttrs)
+                       ->getResult(0);
+
+    if (cOrigTy != cTy)
+      result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
+    switch (pTy) {
+    case xevm::ElemType::TF32:
+      return 1;
+    case xevm::ElemType::BF16:
+    case xevm::ElemType::F16:
+      return 2;
+    case xevm::ElemType::U8:
+    case xevm::ElemType::S8:
+      return 4;
+    default:
+      llvm_unreachable("unsupported xevm::ElemType");
+    }
+  }
+};
+
+class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
+    Value one = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1));
+    SmallVector<Value> args{op.getPtr(), one};
+    SmallVector<Type> argTypes;
+    for (auto arg : args)
+      argTypes.push_back(arg.getType());
+    auto funcAttr = noUnwindAttrs;
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/LLVM::ModRefInfo::NoModRef,
+        /*argMem=*/LLVM::ModRefInfo::Ref,
+        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+    funcAttr.memEffectsAttr = memAttr;
+
+    LLVM::CallOp call = createDeviceFunctionCall(
+        rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
+        argTypes, args, {}, funcAttr);
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata<true>(rewriter, op))
+      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    const std::string fnName{"atomic_work_item_fence"};
+    int memScope, addrSpace;
+    switch (op.getAddrspace()) {
+    case xevm::AddrSpace::SHARED:
+      addrSpace = 1; // CLK_LOCAL_MEM_FENCE
+      break;
+    case xevm::AddrSpace::GLOBAL:
+      addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
+      break;
+  ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes

Co-authored-by: Artem Kroviakov artem.kroviakov@intel.com


Patch is 38.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147375.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+9)
  • (added) mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h (+29)
  • (modified) mlir/include/mlir/InitAllExtensions.h (+2)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt (+21)
  • (added) mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp (+669)
  • (added) mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir (+83)
  • (modified) mlir/test/lib/Dialect/GPU/CMakeLists.txt (+1)
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..8a5976e547169 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -80,6 +80,7 @@
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
 
 namespace mlir {
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..929c3cf5a25ed 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1495,4 +1495,13 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// XeVMToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
+  let summary = "Convert XeVM to LLVM dialect";
+  let dependentDialects = ["xevm::XeVMDialect", ];
+}
+
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
new file mode 100644
index 0000000000000..b361af573afff
--- /dev/null
+++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
@@ -0,0 +1,29 @@
+//===-- XeVMToLLVM.h - Convert XeVM to LLVM dialect -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
+#define MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+} // namespace mlir
+
+namespace mlir {
+#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
+
+void registerConvertXeVMToLLVMInterface(DialectRegistry &registry);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index f356b91b1b6c0..47e2f554abb69 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,6 +32,7 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
@@ -90,6 +91,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
+  registerConvertXeVMToLLVMInterface(registry);
 
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..24a48993ad80c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -73,3 +73,4 @@ add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)
 add_subdirectory(VectorToSPIRV)
 add_subdirectory(VectorToXeGPU)
+add_subdirectory(XeVMToLLVM)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..4ac60d8d43472
--- /dev/null
+++ b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_mlir_conversion_library(MLIRXeVMToLLVM
+  XeVMToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeVMToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRGPUDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRXeVMDialect
+  MLIRPass
+  MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
new file mode 100644
index 0000000000000..89407825fd656
--- /dev/null
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -0,0 +1,669 @@
+//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "xevm-to-llvm"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace xevm;
+
+namespace {
+
+struct LLVMFuncAttributeOptions {
+  bool isConvergent = false;
+  bool isNoUnwind = false;
+  bool isWillReturn = false;
+  LLVM::MemoryEffectsAttr memEffectsAttr{};
+};
+static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
+    false, true, false, {}};
+static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
+    false, true, true, {}};
+static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
+    true, true, true, {}};
+
+std::string getTypeMangling(Type ty, bool isUnsigned = false) {
+  return TypeSwitch<Type, std::string>(ty)
+      .Case([isUnsigned](VectorType ty) -> std::string {
+        return "Dv" + std::to_string(ty.getNumElements()) + "_" +
+               getTypeMangling(ty.getElementType(), isUnsigned);
+      })
+      .Case([](Float16Type) -> std::string { return "Dh"; })
+      .Case([](Float32Type) -> std::string { return "f"; })
+      .Case([](Float64Type) -> std::string { return "d"; })
+      .Case([isUnsigned](IntegerType ty) -> std::string {
+        switch (ty.getWidth()) {
+        case 8:
+          return isUnsigned ? "h" : "c";
+        case 16:
+          return isUnsigned ? "t" : "s";
+        case 32:
+          return isUnsigned ? "j" : "i";
+        case 64:
+          return isUnsigned ? "m" : "l";
+        default:
+          llvm_unreachable("unhandled integer type");
+        }
+      });
+}
+
+std::string mangle(StringRef baseName, ArrayRef<Type> types,
+                   ArrayRef<bool> isUnsigned = {}) {
+  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
+         "Signedness info doesn't match");
+  std::string s;
+  llvm::raw_string_ostream os(s);
+  llvm::SmallDenseMap<Type, unsigned> substitutions;
+  os << "_Z" << baseName.size() << baseName;
+  for (auto [idx, type] : llvm::enumerate(types)) {
+    auto it = substitutions.find(type);
+    if (it != substitutions.end()) {
+      os << "S";
+      // First substitution is `S_`, second is `S0_`, and so on.
+      if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
+        os << firstIdx - 1;
+      os << "_";
+    } else {
+      if (!type.isIntOrFloat())
+        substitutions[type] = substitutions.size();
+      os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
+    }
+  }
+  return os.str();
+}
+
+template <bool isLoad, typename OpType>
+int32_t getL1CacheControl(OpType op) {
+  int32_t control = 0;
+  if constexpr (isLoad) {
+    switch (*op.getCacheControl()) {
+    case LoadCacheControl::L1UC_L2UC_L3UC:
+    case LoadCacheControl::L1UC_L2UC_L3C:
+    case LoadCacheControl::L1UC_L2C_L3UC:
+    case LoadCacheControl::L1UC_L2C_L3C:
+      control = 1;
+      break;
+    case LoadCacheControl::L1C_L2UC_L3UC:
+    case LoadCacheControl::L1C_L2UC_L3C:
+    case LoadCacheControl::L1C_L2C_L3UC:
+    case LoadCacheControl::L1C_L2C_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1S_L2UC_L3UC:
+    case LoadCacheControl::L1S_L2UC_L3C:
+    case LoadCacheControl::L1S_L2C_L3UC:
+    case LoadCacheControl::L1S_L2C_L3C:
+      control = 3;
+      break;
+    case LoadCacheControl::INVALIDATE_READ:
+      control = 4;
+      break;
+    default:
+      break;
+    }
+  } else {
+    switch (*op.getCacheControl()) {
+    case StoreCacheControl::L1UC_L2UC_L3UC:
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+    case StoreCacheControl::L1UC_L2WB_L3UC:
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WT_L2UC_L3UC:
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+    case StoreCacheControl::L1WT_L2WB_L3UC:
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1S_L2UC_L3UC:
+    case StoreCacheControl::L1S_L2UC_L3WB:
+    case StoreCacheControl::L1S_L2WB_L3UC:
+    case StoreCacheControl::L1S_L2WB_L3WB:
+      control = 3;
+      break;
+    case StoreCacheControl::L1WB_L2UC_L3UC:
+    case StoreCacheControl::L1WB_L2WB_L3UC:
+    case StoreCacheControl::L1WB_L2UC_L3WB:
+      control = 4;
+      break;
+    default:
+      break;
+    }
+  }
+  return control;
+}
+
+template <bool isLoad, typename OpType>
+int32_t getL3CacheControl(OpType op) {
+  int32_t control = 0;
+  if constexpr (isLoad) {
+    switch (*op.getCacheControl()) {
+    case LoadCacheControl::L1UC_L2UC_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1UC_L2UC_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1UC_L2C_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1UC_L2C_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1C_L2UC_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1C_L2UC_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1C_L2C_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1C_L2C_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1S_L2UC_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1S_L2UC_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1S_L2C_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1S_L2C_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::INVALIDATE_READ:
+      control = 4;
+      break;
+    default:
+      break;
+    }
+  } else {
+    switch (*op.getCacheControl()) {
+    case StoreCacheControl::L1UC_L2UC_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1UC_L2WB_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1WT_L2UC_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1WT_L2WB_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1S_L2UC_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1S_L2UC_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1S_L2WB_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1S_L2WB_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1WB_L2UC_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WB_L2WB_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WB_L2UC_L3WB:
+      control = 2;
+      break;
+    default:
+      break;
+    }
+  }
+  return control;
+}
+
+template <bool isLoad, typename OpType>
+static std::optional<ArrayAttr>
+getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
+  if constexpr (isLoad) {
+    if (!op.getCacheControl())
+      return {};
+  } else {
+    if (!op.getCacheControl())
+      return {};
+  }
+  constexpr int32_t decorationCacheControlArity{4};
+  constexpr int32_t loadCacheControlKey{6442};
+  constexpr int32_t storeCacheControlKey{6443};
+  const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
+  SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
+      controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
+  SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
+      controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
+  auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
+  auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
+
+  SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
+  return rewriter.getArrayAttr(combinedAttrs);
+}
+
+static LLVM::CallOp createDeviceFunctionCall(
+    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
+    ArrayRef<Type> argTypes, ArrayRef<Value> args,
+    mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
+    LLVMFuncAttributeOptions funcAttributeOptions) {
+  auto moduleOp = rewriter.getBlock()
+                      ->getParentOp()
+                      ->getParentWithTrait<OpTrait::SymbolTable>();
+  assert(moduleOp && "Expecting module");
+  MLIRContext *ctx = rewriter.getContext();
+  Location loc = UnknownLoc::get(ctx);
+
+  auto funcOpRes =
+      LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
+  assert(!failed(funcOpRes));
+  LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+  funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+  funcOp.setConvergent(funcAttributeOptions.isConvergent);
+  funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
+  funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
+
+  if (funcAttributeOptions.memEffectsAttr)
+    funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
+
+  for (auto [idx, attrName] : paramAttrs)
+    funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
+
+  auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
+  callOp->setAttrs(funcOp->getAttrs());
+
+  return callOp;
+}
+
+class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getC()) {
+      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
+    }
+    constexpr uint32_t bitWidthPackedA{16};
+    constexpr uint32_t bitWidthPackedB{32};
+    auto loc = op.getLoc();
+
+    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
+      VectorType origTy = cast<VectorType>(val.getType());
+      const uint32_t vecBitSize =
+          origTy.getNumElements() *
+          origTy.getElementType().getIntOrFloatBitWidth();
+      VectorType newTy = VectorType::get(
+          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
+      if (origTy != newTy)
+        val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
+      return val;
+    };
+
+    Value a = op.getA();
+    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedA);
+    a = castIfNeeded(a, packedAType);
+
+    Value b = op.getB();
+    Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedB);
+    b = castIfNeeded(b, packedBType);
+
+    Value c = op.getC();
+    VectorType cOrigTy = cast<VectorType>(c.getType());
+    assert(cOrigTy == op->getResultTypes()[0] &&
+           "Accumulator and result type mismatch");
+    // OCL builtins encode bfloat16 as int16
+    VectorType cTy =
+        cOrigTy.getElementType().isBF16()
+            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
+            : cOrigTy;
+    if (cOrigTy != cTy)
+      c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
+
+    constexpr int32_t systolicDepth{8};
+    std::string fnName =
+        llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
+                      stringifyElemType(op.getTypes().getA()).str(),
+                      stringifyElemType(op.getTypes().getB()).str(),
+                      systolicDepth *
+                          getNumOperandsPerDword(op.getTypes().getA()))
+            .str();
+    SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
+    fnName = mangle(fnName, argTypes);
+    SmallVector<Value> args{a, b, c};
+
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/LLVM::ModRefInfo::NoModRef,
+        /*argMem=*/LLVM::ModRefInfo::NoModRef,
+        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
+    funcAttrs.memEffectsAttr = memAttr;
+    Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes,
+                                            args, {}, funcAttrs)
+                       ->getResult(0);
+
+    if (cOrigTy != cTy)
+      result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
+    switch (pTy) {
+    case xevm::ElemType::TF32:
+      return 1;
+    case xevm::ElemType::BF16:
+    case xevm::ElemType::F16:
+      return 2;
+    case xevm::ElemType::U8:
+    case xevm::ElemType::S8:
+      return 4;
+    default:
+      llvm_unreachable("unsupported xevm::ElemType");
+    }
+  }
+};
+
+class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
+    Value one = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1));
+    SmallVector<Value> args{op.getPtr(), one};
+    SmallVector<Type> argTypes;
+    for (auto arg : args)
+      argTypes.push_back(arg.getType());
+    auto funcAttr = noUnwindAttrs;
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/LLVM::ModRefInfo::NoModRef,
+        /*argMem=*/LLVM::ModRefInfo::Ref,
+        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+    funcAttr.memEffectsAttr = memAttr;
+
+    LLVM::CallOp call = createDeviceFunctionCall(
+        rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
+        argTypes, args, {}, funcAttr);
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata<true>(rewriter, op))
+      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    const std::string fnName{"atomic_work_item_fence"};
+    int memScope, addrSpace;
+    switch (op.getAddrspace()) {
+    case xevm::AddrSpace::SHARED:
+      addrSpace = 1; // CLK_LOCAL_MEM_FENCE
+      break;
+    case xevm::AddrSpace::GLOBAL:
+      addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
+      break;
+  ...
[truncated]

@silee2
Copy link
Contributor Author

silee2 commented Jul 7, 2025

@rengolin rengolin requested review from adam-smnk and Jianhui-Li July 7, 2025 21:22
@Garra1980
Copy link

also cc @akroviakov

template <bool isLoad, typename OpType>
static std::optional<ArrayAttr>
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
if constexpr (isLoad) {
Copy link
Contributor

@akroviakov akroviakov Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: both branches seem to be the same. Could one reduce this check to

    if (!op.getCacheControl())
      return {};

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Didn't realize they are the same.
Optimized as suggested.

Value memScopeConst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(memScope));
Value addrSpaceConst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(addrSpace));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: IntegerAttr should be avoidable for LLVM::ConstantOp

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

@kurapov-peter kurapov-peter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this not in the translation?

@adam-smnk
Copy link
Contributor

adam-smnk commented Jul 8, 2025

Why is this not in the translation?

This is neither lowering to built-in LLVM intrinsic (they are proper intrinsics but not standalone LLVM ops) nor using any automatically generated conversions.
It's better to retain this staging ground between xevm and LLVM proper through lowering to MLIR LLVMIR first.

Conversion to llvm seems like suitable abstraction. Similarly to how x86vector or amx intrinsics are generated today through opaque LLVMIR function calls which are fully translated to LLVM later.

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good from perspective of infrastructure and overall flow

Minor comments about the lowering logic itself
I'll leave in-depth review of that to @Jianhui-Li

private:
static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
switch (pTy) {
case xevm::ElemType::TF32:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why no f32?
It seems valid looking at the op definition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

f32 is valid in op definition for C in the formula for MMA below.
D= A*B+C
A and B does not support f32
getNumOperandsPerDword is a function for types supported by A and B if sub dword (or 32bit) types need to be packed/bundled into a dword. As the name implies, it returns how many operands of pTy is packed per dword.

Copy link
Contributor

@adam-smnk adam-smnk Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A and B does not support f32

In this case, the op definition has to be further constrained. At the moment, xevm.mma is happy to take f32 for all its operands.

Copy link
Contributor Author

@silee2 silee2 Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad.
MMA op can have mismatch in

  • operand element type and
  • types declared in attribute

For example,

%d = xevm.mma %a, %b, %c { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> }
             : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>

i16 and i32 are packed type.

f32 can appear in A or B if the true type or declared type in attribute is tf32
packed type for tf32 is f32

In short, f32 is a valid element type for operand A or B

Copy link
Contributor

@adam-smnk adam-smnk Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My larger point is that, for example, this op is valid today:

func.func @mma_f32(%c: vector<8xf32>, %a: vector<8xf32>,
    %b: vector<8xf32>) -> vector<8xf32> {
  %d = xevm.mma %a, %b, %c
  { shape=<m=8, n=16, k=16>,
    types=<d=f32, a=f32, b=f32, c=f32> }
    : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
  return %d : vector<8xf32>
}

Ideally, unsupported combinations would be caught by op verifier as it's nice to have that as an invariant.
If there's a potential use case for such combination (some rewrites, canonicalization etc. directly on xevm ops), then please at least add extra checks to this pass so it doesn't blow up on an assertion.

// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind,
// CHECK-SAME: sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems from the code that the cache hint is attached as attribute to the call. Is this being checked ?

if (std::optional<ArrayAttr> optCacheControls =
        getCacheControlMetadata<true>(rewriter, op))
  call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is.
But not part of this PR. Metadata is consumed during translation.
That will be done in a follow up PR but I can provide a link to code in the POC branch consumes the metadata.
https://github.com/silee2/llvm-project/blob/xevmWorkspace/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp#L50-L72

// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add more tests to cover a few combination of these parameters, say the following:

  1. load 32x16xi16, with v_blocks =2
  2. load 16x16xi32, transpose=true
  3. load 16x16xi16, pack_register=true

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added v_blocks, transpose and pack_register and transpose test cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great. Can you also change the shape size as described above also?

auto funcAttrs = convergentNoUnwindWillReturnAttrs;
funcAttrs.memEffectsAttr = memAttr;
Value result =
createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, args, {},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please consider rename cTy to dTy since this parameter specify return type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, there are a few warning in CI. Please address them before merging.

private:
static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
switch (pTy) {
case xevm::ElemType::TF32:
Copy link
Contributor

@adam-smnk adam-smnk Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My larger point is that, for example, this op is valid today:

func.func @mma_f32(%c: vector<8xf32>, %a: vector<8xf32>,
    %b: vector<8xf32>) -> vector<8xf32> {
  %d = xevm.mma %a, %b, %c
  { shape=<m=8, n=16, k=16>,
    types=<d=f32, a=f32, b=f32, c=f32> }
    : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
  return %d : vector<8xf32>
}

Ideally, unsupported combinations would be caught by op verifier as it's nice to have that as an invariant.
If there's a potential use case for such combination (some rewrites, canonicalization etc. directly on xevm ops), then please at least add extra checks to this pass so it doesn't blow up on an assertion.

@silee2
Copy link
Contributor Author

silee2 commented Jul 10, 2025

Ideally, unsupported combinations would be caught by op verifier as it's nice to have that as an invariant.
If there's a potential use case for such combination (some rewrites, canonicalization etc. directly on xevm ops), then please at least add extra checks to this pass so it doesn't blow up on an assertion.

Valid point. I'll keep that in mind.
Unsupported combination sometimes depends on target uArch and available backend intrinsics.
For now, op verifier tries to validate common invariants that hold regardless.
Meanwhile, I'll add extra checks to the pass as you have suggested until there is a better solution.

Copy link
Contributor

@Jianhui-Li Jianhui-Li left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants