From 915ca433b998734a1efd9f8462ef44b194665c99 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Mon, 7 Jul 2025 18:49:14 +0000 Subject: [PATCH 01/11] Add convert-xevm-to-llvm pass. Co-authored-by: Artem Kroviakov artem.kroviakov@intel.com --- mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 9 + .../mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h | 29 + mlir/include/mlir/InitAllExtensions.h | 2 + mlir/lib/Conversion/CMakeLists.txt | 1 + mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt | 21 + mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 669 ++++++++++++++++++ .../Conversion/XeVMToLLVM/xevm-to-llvm.mlir | 83 +++ 8 files changed, 815 insertions(+) create mode 100644 mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h create mode 100644 mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt create mode 100644 mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp create mode 100644 mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index c9d2a54433736..8a5976e547169 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -80,6 +80,7 @@ #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h" #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h" +#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" namespace mlir { diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 5a864865adffc..929c3cf5a25ed 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1495,4 +1495,13 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> { ]; } +//===----------------------------------------------------------------------===// +// XeVMToLLVM +//===----------------------------------------------------------------------===// + +def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> { + let summary = "Convert XeVM to LLVM dialect"; + let dependentDialects = ["xevm::XeVMDialect", ]; +} + #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h new file mode 100644 index 0000000000000..b361af573afff --- /dev/null +++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h @@ -0,0 +1,29 @@ +//===-- XeVMToLLVM.h - Convert XeVM to LLVM dialect -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_ +#define MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_ + +#include + +namespace mlir { +class DialectRegistry; +class LLVMTypeConverter; +class RewritePatternSet; +class Pass; +} // namespace mlir + +namespace mlir { +#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS +#include "mlir/Conversion/Passes.h.inc" + +void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns); + +void registerConvertXeVMToLLVMInterface(DialectRegistry ®istry); +} // namespace mlir + +#endif // MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_ diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 0f2d0e45008cc..d5a9a2c3aeba7 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -32,6 +32,7 @@ #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" @@ -91,6 +92,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { gpu::registerConvertGpuToLLVMInterface(registry); NVVM::registerConvertGpuToNVVMInterface(registry); vector::registerConvertVectorToLLVMInterface(registry); + registerConvertXeVMToLLVMInterface(registry); // Register all transform dialect extensions. affine::registerTransformDialectExtension(registry); diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index e4b4974600577..24a48993ad80c 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -73,3 +73,4 @@ add_subdirectory(VectorToLLVM) add_subdirectory(VectorToSCF) add_subdirectory(VectorToSPIRV) add_subdirectory(VectorToXeGPU) +add_subdirectory(XeVMToLLVM) diff --git a/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt new file mode 100644 index 0000000000000..4ac60d8d43472 --- /dev/null +++ b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt @@ -0,0 +1,21 @@ +add_mlir_conversion_library(MLIRXeVMToLLVM + XeVMToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeVMToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRGPUDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRXeVMDialect + MLIRPass + MLIRTransforms +) diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp new file mode 100644 index 0000000000000..89407825fd656 --- /dev/null +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -0,0 +1,669 @@ +//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "xevm-to-llvm"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace xevm;
+
+namespace {
+
+struct LLVMFuncAttributeOptions {
+  bool isConvergent = false;
+  bool isNoUnwind = false;
+  bool isWillReturn = false;
+  LLVM::MemoryEffectsAttr memEffectsAttr{};
+};
+static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
+    false, true, false, {}};
+static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
+    false, true, true, {}};
+static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
+    true, true, true, {}};
+
+std::string getTypeMangling(Type ty, bool isUnsigned = false) {
+  return TypeSwitch<Type, std::string>(ty)
+      .Case([isUnsigned](VectorType ty) -> std::string {
+        return "Dv" + std::to_string(ty.getNumElements()) + "_" +
+               getTypeMangling(ty.getElementType(), isUnsigned);
+      })
+      .Case([](Float16Type) -> std::string { return "Dh"; })
+      .Case([](Float32Type) -> std::string { return "f"; })
+      .Case([](Float64Type) -> std::string { return "d"; })
+      .Case([isUnsigned](IntegerType ty) -> std::string {
+        switch (ty.getWidth()) {
+        case 8:
+          return isUnsigned ? "h" : "c";
+        case 16:
+          return isUnsigned ? "t" : "s";
+        case 32:
+          return isUnsigned ? "j" : "i";
+        case 64:
+          return isUnsigned ? "m" : "l";
+        default:
+          llvm_unreachable("unhandled integer type");
+        }
+      });
+}
+
+std::string mangle(StringRef baseName, ArrayRef<Type> types,
+                   ArrayRef<bool> isUnsigned = {}) {
+  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
+         "Signedness info doesn't match");
+  std::string s;
+  llvm::raw_string_ostream os(s);
+  llvm::SmallDenseMap<Type, unsigned> substitutions;
+  os << "_Z" << baseName.size() << baseName;
+  for (auto [idx, type] : llvm::enumerate(types)) {
+    auto it = substitutions.find(type);
+    if (it != substitutions.end()) {
+      os << "S";
+      // First substitution is `S_`, second is `S0_`, and so on.
+      if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
+        os << firstIdx - 1;
+      os << "_";
+    } else {
+      if (!type.isIntOrFloat())
+        substitutions[type] = substitutions.size();
+      os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]); + } + } + return os.str(); +} + +template +int32_t getL1CacheControl(OpType op) { + int32_t control = 0; + if constexpr (isLoad) { + switch (*op.getCacheControl()) { + case LoadCacheControl::L1UC_L2UC_L3UC: + case LoadCacheControl::L1UC_L2UC_L3C: + case LoadCacheControl::L1UC_L2C_L3UC: + case LoadCacheControl::L1UC_L2C_L3C: + control = 1; + break; + case LoadCacheControl::L1C_L2UC_L3UC: + case LoadCacheControl::L1C_L2UC_L3C: + case LoadCacheControl::L1C_L2C_L3UC: + case LoadCacheControl::L1C_L2C_L3C: + control = 2; + break; + case LoadCacheControl::L1S_L2UC_L3UC: + case LoadCacheControl::L1S_L2UC_L3C: + case LoadCacheControl::L1S_L2C_L3UC: + case LoadCacheControl::L1S_L2C_L3C: + control = 3; + break; + case LoadCacheControl::INVALIDATE_READ: + control = 4; + break; + default: + break; + } + } else { + switch (*op.getCacheControl()) { + case StoreCacheControl::L1UC_L2UC_L3UC: + case StoreCacheControl::L1UC_L2UC_L3WB: + case StoreCacheControl::L1UC_L2WB_L3UC: + case StoreCacheControl::L1UC_L2WB_L3WB: + control = 1; + break; + case StoreCacheControl::L1WT_L2UC_L3UC: + case StoreCacheControl::L1WT_L2UC_L3WB: + case StoreCacheControl::L1WT_L2WB_L3UC: + case StoreCacheControl::L1WT_L2WB_L3WB: + control = 2; + break; + case StoreCacheControl::L1S_L2UC_L3UC: + case StoreCacheControl::L1S_L2UC_L3WB: + case StoreCacheControl::L1S_L2WB_L3UC: + case StoreCacheControl::L1S_L2WB_L3WB: + control = 3; + break; + case StoreCacheControl::L1WB_L2UC_L3UC: + case StoreCacheControl::L1WB_L2WB_L3UC: + case StoreCacheControl::L1WB_L2UC_L3WB: + control = 4; + break; + default: + break; + } + } + return control; +} + +template +int32_t getL3CacheControl(OpType op) { + int32_t control = 0; + if constexpr (isLoad) { + switch (*op.getCacheControl()) { + case LoadCacheControl::L1UC_L2UC_L3UC: + control = 1; + break; + case LoadCacheControl::L1UC_L2UC_L3C: + control = 2; + break; + case LoadCacheControl::L1UC_L2C_L3UC: + control = 1; + break; + case LoadCacheControl::L1UC_L2C_L3C: + control = 2; + break; + case LoadCacheControl::L1C_L2UC_L3UC: + control = 1; + break; + case LoadCacheControl::L1C_L2UC_L3C: + control = 2; + break; + case LoadCacheControl::L1C_L2C_L3UC: + control = 1; + break; + case LoadCacheControl::L1C_L2C_L3C: + control = 2; + break; + case LoadCacheControl::L1S_L2UC_L3UC: + control = 1; + break; + case LoadCacheControl::L1S_L2UC_L3C: + control = 2; + break; + case LoadCacheControl::L1S_L2C_L3UC: + control = 1; + break; + case LoadCacheControl::L1S_L2C_L3C: + control = 2; + break; + case LoadCacheControl::INVALIDATE_READ: + control = 4; + break; + default: + break; + } + } else { + switch (*op.getCacheControl()) { + case StoreCacheControl::L1UC_L2UC_L3UC: + control = 1; + break; + case StoreCacheControl::L1UC_L2UC_L3WB: + control = 2; + break; + case StoreCacheControl::L1UC_L2WB_L3UC: + control = 1; + break; + case StoreCacheControl::L1UC_L2WB_L3WB: + control = 2; + break; + case StoreCacheControl::L1WT_L2UC_L3UC: + control = 1; + break; + case StoreCacheControl::L1WT_L2UC_L3WB: + control = 2; + break; + case StoreCacheControl::L1WT_L2WB_L3UC: + control = 1; + break; + case StoreCacheControl::L1WT_L2WB_L3WB: + control = 2; + break; + case StoreCacheControl::L1S_L2UC_L3UC: + control = 1; + break; + case StoreCacheControl::L1S_L2UC_L3WB: + control = 2; + break; + case StoreCacheControl::L1S_L2WB_L3UC: + control = 1; + break; + case StoreCacheControl::L1S_L2WB_L3WB: + control = 2; + break; + case StoreCacheControl::L1WB_L2UC_L3UC: + control = 1; + break; + case 
StoreCacheControl::L1WB_L2WB_L3UC: + control = 1; + break; + case StoreCacheControl::L1WB_L2UC_L3WB: + control = 2; + break; + default: + break; + } + } + return control; +} + +template +static std::optional +getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { + if constexpr (isLoad) { + if (!op.getCacheControl()) + return {}; + } else { + if (!op.getCacheControl()) + return {}; + } + constexpr int32_t decorationCacheControlArity{4}; + constexpr int32_t loadCacheControlKey{6442}; + constexpr int32_t storeCacheControlKey{6443}; + const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey}; + SmallVector decorationsL1{ + controlKey, 0, getL1CacheControl(op), 0}; + SmallVector decorationsL3{ + controlKey, 1, getL3CacheControl(op), 0}; + auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1); + auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3); + + SmallVector combinedAttrs = {arrayAttrL1, arrayAttrL3}; + return rewriter.getArrayAttr(combinedAttrs); +} + +static LLVM::CallOp createDeviceFunctionCall( + ConversionPatternRewriter &rewriter, StringRef funcName, Type retType, + ArrayRef argTypes, ArrayRef args, + mlir::ArrayRef> paramAttrs, + LLVMFuncAttributeOptions funcAttributeOptions) { + auto moduleOp = rewriter.getBlock() + ->getParentOp() + ->getParentWithTrait(); + assert(moduleOp && "Expecting module"); + MLIRContext *ctx = rewriter.getContext(); + Location loc = UnknownLoc::get(ctx); + + auto funcOpRes = + LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType); + assert(!failed(funcOpRes)); + LLVM::LLVMFuncOp funcOp = funcOpRes.value(); + funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC); + funcOp.setConvergent(funcAttributeOptions.isConvergent); + funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind); + funcOp.setWillReturn(funcAttributeOptions.isWillReturn); + + if (funcAttributeOptions.memEffectsAttr) + funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr); + + for (auto [idx, attrName] : paramAttrs) + funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr()); + + auto callOp = rewriter.create(loc, funcOp, args); + callOp->setAttrs(funcOp->getAttrs()); + + return callOp; +} + +class MMAToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getC()) { + return rewriter.notifyMatchFailure(op, "OCL requires C operand"); + } + constexpr uint32_t bitWidthPackedA{16}; + constexpr uint32_t bitWidthPackedB{32}; + auto loc = op.getLoc(); + + auto castIfNeeded = [&](Value val, Type packedType) -> Value { + VectorType origTy = cast(val.getType()); + const uint32_t vecBitSize = + origTy.getNumElements() * + origTy.getElementType().getIntOrFloatBitWidth(); + VectorType newTy = VectorType::get( + vecBitSize / packedType.getIntOrFloatBitWidth(), packedType); + if (origTy != newTy) + val = rewriter.create(loc, newTy, val); + return val; + }; + + Value a = op.getA(); + Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32) + ? cast(rewriter.getF32Type()) + : rewriter.getIntegerType(bitWidthPackedA); + a = castIfNeeded(a, packedAType); + + Value b = op.getB(); + Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32) + ? 
cast(rewriter.getF32Type()) + : rewriter.getIntegerType(bitWidthPackedB); + b = castIfNeeded(b, packedBType); + + Value c = op.getC(); + VectorType cOrigTy = cast(c.getType()); + assert(cOrigTy == op->getResultTypes()[0] && + "Accumulator and result type mismatch"); + // OCL builtins encode bfloat16 as int16 + VectorType cTy = + cOrigTy.getElementType().isBF16() + ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16)) + : cOrigTy; + if (cOrigTy != cTy) + c = rewriter.create(loc, cTy, c); + + constexpr int32_t systolicDepth{8}; + std::string fnName = + llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}", + stringifyElemType(op.getTypes().getA()).str(), + stringifyElemType(op.getTypes().getB()).str(), + systolicDepth * + getNumOperandsPerDword(op.getTypes().getA())) + .str(); + SmallVector argTypes{a.getType(), b.getType(), cTy}; + fnName = mangle(fnName, argTypes); + SmallVector args{a, b, c}; + + auto memAttr = rewriter.getAttr( + /*other=*/LLVM::ModRefInfo::NoModRef, + /*argMem=*/LLVM::ModRefInfo::NoModRef, + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + auto funcAttrs = convergentNoUnwindWillReturnAttrs; + funcAttrs.memEffectsAttr = memAttr; + Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, + args, {}, funcAttrs) + ->getResult(0); + + if (cOrigTy != cTy) + result = rewriter.create(loc, cOrigTy, result); + + rewriter.replaceOp(op, result); + return success(); + } + +private: + static unsigned getNumOperandsPerDword(xevm::ElemType pTy) { + switch (pTy) { + case xevm::ElemType::TF32: + return 1; + case xevm::ElemType::BF16: + case xevm::ElemType::F16: + return 2; + case xevm::ElemType::U8: + case xevm::ElemType::S8: + return 4; + default: + llvm_unreachable("unsupported xevm::ElemType"); + } + } +}; + +class PrefetchToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + const std::string fnName{"_Z8prefetchPU3AS1Kcm"}; + Value one = rewriter.create( + loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1)); + SmallVector args{op.getPtr(), one}; + SmallVector argTypes; + for (auto arg : args) + argTypes.push_back(arg.getType()); + auto funcAttr = noUnwindAttrs; + auto memAttr = rewriter.getAttr( + /*other=*/LLVM::ModRefInfo::NoModRef, + /*argMem=*/LLVM::ModRefInfo::Ref, + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + funcAttr.memEffectsAttr = memAttr; + + LLVM::CallOp call = createDeviceFunctionCall( + rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()), + argTypes, args, {}, funcAttr); + if (std::optional optCacheControls = + getCacheControlMetadata(rewriter, op)) + call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); + rewriter.eraseOp(op); + return success(); + } +}; + +class MemfenceToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + const std::string fnName{"atomic_work_item_fence"}; + int memScope, addrSpace; + switch (op.getAddrspace()) { + case xevm::AddrSpace::SHARED: + addrSpace = 1; // CLK_LOCAL_MEM_FENCE + break; + case xevm::AddrSpace::GLOBAL: + addrSpace = 2; // CLK_GLOBAL_MEM_FENCE + break; + default: + // GENERIC is not supported in OpenCL + llvm_unreachable("Fence only supports global and 
shared address spaces."); + } + switch (op.getScope()) { + case xevm::MemScope::WORKGROUP: + memScope = 1; + break; + case xevm::MemScope::DEVICE: + memScope = 2; + break; + default: + // CLUSTER and SYSTEM are not supported in OpenCL + llvm_unreachable("unsupported xevm::MemoryScope"); + } + Value acqRel = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(4)); + Value memScopeConst = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(memScope)); + Value addrSpaceConst = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(addrSpace)); + SmallVector args{addrSpaceConst, acqRel, memScopeConst}; + SmallVector argTypes{3, rewriter.getI32Type()}; + createDeviceFunctionCall(rewriter, mangle(fnName, argTypes), + LLVM::LLVMVoidType::get(rewriter.getContext()), + argTypes, args, {}, noUnwindAttrs); + rewriter.eraseOp(op); + return success(); + } +}; +template +class LoadStorePrefetchToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + constexpr bool isLoad = std::is_same_v; + constexpr bool isPrefetch = std::is_same_v; + + auto loc = op.getLoc(); + VectorType vecType; + bool packReg = false; + bool transpose = false; + if constexpr (isLoad) { + vecType = op.getRes().getType(); + packReg = op.getPackRegister(); + transpose = op.getTranspose(); + } else if constexpr (!isPrefetch) { + vecType = op.getStoredVal().getType(); + } + + auto i32Type = rewriter.getI32Type(); + Value byteCoord = + rewriter.create(loc, VectorType::get(2, i32Type)); + Value zero = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(0)); + Value one = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(1)); + byteCoord = rewriter.create( + loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); + byteCoord = rewriter.create( + loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); + SmallVector args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(), + op.getBasePitch(), byteCoord}; + SmallVector retTypes; + Value spvLoadDstPtr; + std::string funcName{"intel_sub_group_2d_block_"}; + std::string bitWidthId; + LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs}; + SmallVector, 4> paramAttrs; + if constexpr (isPrefetch) { // Prefetch + funcName += "prefetch"; + paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())}; + auto memAttr = rewriter.getAttr( + /*other=*/LLVM::ModRefInfo::NoModRef, + /*argMem=*/LLVM::ModRefInfo::Ref, + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + funcAttr = noUnwindAttrs; + funcAttr.memEffectsAttr = memAttr; + } else { + auto vecElemType = vecType.getElementType(); + auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth(); + Value numElems = rewriter.create( + loc, i32Type, vecType.getNumElements()); + auto dstOrSrcPtr = rewriter.create( + loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType, + numElems); + args.push_back(dstOrSrcPtr); + if constexpr (isLoad) { // Load + funcName += "read"; + bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true); + if (packReg) + funcName += "_transform"; + else if (transpose) + funcName += "_transpose"; + spvLoadDstPtr = dstOrSrcPtr; + retTypes.push_back(vecType); + paramAttrs = { + std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), + std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()), + std::make_pair(5, 
LLVM::LLVMDialect::getNonNullAttrName()),
+            std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
+        };
+      } else { // Store
+        funcName += "write";
+        bitWidthId = (vecElemBitWidth == 32)
+                         ? "j"
+                         : ((vecElemBitWidth == 16) ? "t" : "h");
+        rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr);
+        paramAttrs = {
+            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
+            std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
+            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
+            std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
+        };
+      }
+    }
+
+    funcName =
+        llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
+                      op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
+            .str();
+    funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
+                             funcName, isPrefetch ? "" : "P", bitWidthId)
+                   .str();
+    SmallVector<Type> argTypes;
+    for (auto arg : args) {
+      argTypes.push_back(arg.getType());
+    }
+    LLVM::CallOp call = createDeviceFunctionCall(
+        rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
+        argTypes, args, paramAttrs, funcAttr);
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
+      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    }
+    if constexpr (isLoad)
+      rewriter.replaceOp(
+          op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
+    else
+      rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+struct ConvertXeVMToLLVMPass
+    : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
+  using Base::Base;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
+  void runOnOperation() override {
+    ConversionTarget target(getContext());
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    target.addIllegalDialect<XeVMDialect>();
+    RewritePatternSet patterns(&getContext());
+    populateXeVMToLLVMConversionPatterns(patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPatternInterface implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Implement the interface to convert XeVM to LLVM.
+struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<LLVM::LLVMDialect>();
+  }
+
+  /// Hook for derived dialect interface to provide conversion patterns
+  /// and mark dialect legal for the conversion target.
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateXeVMToLLVMConversionPatterns(patterns);
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
+               LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
+               LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
+               MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
+      patterns.getContext());
+}
+
+void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) {
+    dialect->addInterfaces<XeVMToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
new file mode 100644
index 0000000000000..2d9a9d5683756
--- /dev/null
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt --convert-xevm-to-llvm --split-input-file %s | FileCheck %s
+
+// Same below, but using the `ConvertToLLVMPatternInterface` entry point
+// and the generic `convert-to-llvm` pass.
+// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
+
+// CHECK: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
+// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -> vector<8xi16> {
+llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
+  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
+  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
+  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+  // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>, linkage = #llvm.linkage<external>, no_unwind, sym_name = "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, will_return} : (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
+  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
+  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+  llvm.return %loaded_a : vector<8xi16>
+}
+
+// -----
+// CHECK: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, 
vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) { +llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr + // CHECK: llvm.store %[[ARG6]], %[[VAR6]] : vector<8xi32>, !llvm.ptr + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, linkage = #llvm.linkage, no_unwind, sym_name = "_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, will_return} : (!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) -> () + xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + llvm.return +} + +// ----- +// CHECK: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects, no_unwind} +// CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) { +llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]]) {function_type = !llvm.func, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage, memory_effects = #llvm.memory_effects, no_unwind, sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64 + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, cache_control=#xevm.load_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- +// CHECK: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = 
#llvm.memory_effects, no_unwind, will_return} +// CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> { +llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> { + // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type = !llvm.func (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage, memory_effects = #llvm.memory_effects, no_unwind, sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted { shape=, types= } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + llvm.return %c_result : vector<8xf32> +} + +// ----- +// CHECK: llvm.func spir_funccc @_Z22atomic_work_item_fenceiii(i32, i32, i32) attributes {no_unwind} +llvm.func @memfence() { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z22atomic_work_item_fenceiii(%[[VAR2]], %[[VAR0]], %[[VAR1]]) {function_type = !llvm.func, linkage = #llvm.linkage, no_unwind, sym_name = "_Z22atomic_work_item_fenceiii", visibility_ = 0 : i64} : (i32, i32, i32) -> () + xevm.memfence <{addrspace=#xevm.addr_space, scope=#xevm.mem_scope}> + llvm.return +} + +// ----- +// CHECK: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes {memory_effects = #llvm.memory_effects, no_unwind} +// CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) { +llvm.func @prefetch(%ptr: !llvm.ptr<1>) { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]]) {function_type = !llvm.func, i64)>, linkage = #llvm.linkage, memory_effects = #llvm.memory_effects, no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64 + xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + llvm.return +} + From 11175c6a5e06a48f49df2a18fc33498f60a5d475 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Wed, 9 Jul 2025 03:54:57 +0000 Subject: [PATCH 02/11] Address reviewer comments. 
--- mlir/include/mlir/Conversion/Passes.td | 2 +- .../mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h | 2 - mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 125 ++++++------------ .../Conversion/XeVMToLLVM/xevm-to-llvm.mlir | 86 +++++++++--- 4 files changed, 106 insertions(+), 109 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 929c3cf5a25ed..78e68632409d8 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1501,7 +1501,7 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> { def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> { let summary = "Convert XeVM to LLVM dialect"; - let dependentDialects = ["xevm::XeVMDialect", ]; + let dependentDialects = ["LLVM::LLVMDialect", ]; } #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h index b361af573afff..7ffdbd4307f9e 100644 --- a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h +++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h @@ -15,9 +15,7 @@ class DialectRegistry; class LLVMTypeConverter; class RewritePatternSet; class Pass; -} // namespace mlir -namespace mlir { #define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 89407825fd656..4605ef78ee50d 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -23,8 +23,6 @@ #include "llvm/ADT/TypeSwitch.h" -#define DEBUG_TYPE "xevm-to-llvm" - namespace mlir { #define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" @@ -70,6 +68,9 @@ std::string getTypeMangling(Type ty, bool isUnsigned = false) { default: llvm_unreachable("unhandled integer type"); } + }) + .Default([](Type) -> std::string { + llvm_unreachable("unhandled type for mangling"); }); } @@ -165,38 +166,18 @@ int32_t getL3CacheControl(OpType op) { if constexpr (isLoad) { switch (*op.getCacheControl()) { case LoadCacheControl::L1UC_L2UC_L3UC: - control = 1; - break; - case LoadCacheControl::L1UC_L2UC_L3C: - control = 2; - break; case LoadCacheControl::L1UC_L2C_L3UC: - control = 1; - break; - case LoadCacheControl::L1UC_L2C_L3C: - control = 2; - break; case LoadCacheControl::L1C_L2UC_L3UC: - control = 1; - break; - case LoadCacheControl::L1C_L2UC_L3C: - control = 2; - break; case LoadCacheControl::L1C_L2C_L3UC: - control = 1; - break; - case LoadCacheControl::L1C_L2C_L3C: - control = 2; - break; case LoadCacheControl::L1S_L2UC_L3UC: - control = 1; - break; - case LoadCacheControl::L1S_L2UC_L3C: - control = 2; - break; case LoadCacheControl::L1S_L2C_L3UC: control = 1; break; + case LoadCacheControl::L1UC_L2UC_L3C: + case LoadCacheControl::L1UC_L2C_L3C: + case LoadCacheControl::L1C_L2UC_L3C: + case LoadCacheControl::L1C_L2C_L3C: + case LoadCacheControl::L1S_L2UC_L3C: case LoadCacheControl::L1S_L2C_L3C: control = 2; break; @@ -209,47 +190,21 @@ int32_t getL3CacheControl(OpType op) { } else { switch (*op.getCacheControl()) { case StoreCacheControl::L1UC_L2UC_L3UC: - control = 1; - break; - case StoreCacheControl::L1UC_L2UC_L3WB: - control = 2; - break; case StoreCacheControl::L1UC_L2WB_L3UC: - control = 1; - break; - case StoreCacheControl::L1UC_L2WB_L3WB: - control = 2; - break; case StoreCacheControl::L1WT_L2UC_L3UC: - control = 1; - break; - case StoreCacheControl::L1WT_L2UC_L3WB: - control = 2; 
- break; case StoreCacheControl::L1WT_L2WB_L3UC: - control = 1; - break; - case StoreCacheControl::L1WT_L2WB_L3WB: - control = 2; - break; case StoreCacheControl::L1S_L2UC_L3UC: - control = 1; - break; - case StoreCacheControl::L1S_L2UC_L3WB: - control = 2; - break; case StoreCacheControl::L1S_L2WB_L3UC: - control = 1; - break; - case StoreCacheControl::L1S_L2WB_L3WB: - control = 2; - break; case StoreCacheControl::L1WB_L2UC_L3UC: - control = 1; - break; case StoreCacheControl::L1WB_L2WB_L3UC: control = 1; break; + case StoreCacheControl::L1UC_L2UC_L3WB: + case StoreCacheControl::L1UC_L2WB_L3WB: + case StoreCacheControl::L1WT_L2UC_L3WB: + case StoreCacheControl::L1WT_L2WB_L3WB: + case StoreCacheControl::L1S_L2UC_L3WB: + case StoreCacheControl::L1S_L2WB_L3WB: case StoreCacheControl::L1WB_L2UC_L3WB: control = 2; break; @@ -263,13 +218,8 @@ int32_t getL3CacheControl(OpType op) { template static std::optional getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { - if constexpr (isLoad) { - if (!op.getCacheControl()) - return {}; - } else { - if (!op.getCacheControl()) - return {}; - } + if (!op.getCacheControl()) + return {}; constexpr int32_t decorationCacheControlArity{4}; constexpr int32_t loadCacheControlKey{6442}; constexpr int32_t storeCacheControlKey{6443}; @@ -289,13 +239,12 @@ static LLVM::CallOp createDeviceFunctionCall( ConversionPatternRewriter &rewriter, StringRef funcName, Type retType, ArrayRef argTypes, ArrayRef args, mlir::ArrayRef> paramAttrs, - LLVMFuncAttributeOptions funcAttributeOptions) { + LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) { auto moduleOp = rewriter.getBlock() ->getParentOp() ->getParentWithTrait(); assert(moduleOp && "Expecting module"); - MLIRContext *ctx = rewriter.getContext(); - Location loc = UnknownLoc::get(ctx); + Location loc = op->getLoc(); auto funcOpRes = LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType); @@ -384,9 +333,10 @@ class MMAToOCLPattern : public OpConversionPattern { /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); auto funcAttrs = convergentNoUnwindWillReturnAttrs; funcAttrs.memEffectsAttr = memAttr; - Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, - args, {}, funcAttrs) - ->getResult(0); + Value result = + createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, args, {}, + funcAttrs, op.getOperation()) + ->getResult(0); if (cOrigTy != cTy) result = rewriter.create(loc, cOrigTy, result); @@ -419,8 +369,8 @@ class PrefetchToOCLPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); const std::string fnName{"_Z8prefetchPU3AS1Kcm"}; - Value one = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1)); + Value one = + rewriter.create(loc, rewriter.getI64Type(), 1); SmallVector args{op.getPtr(), one}; SmallVector argTypes; for (auto arg : args) @@ -434,7 +384,7 @@ class PrefetchToOCLPattern : public OpConversionPattern { LLVM::CallOp call = createDeviceFunctionCall( rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()), - argTypes, args, {}, funcAttr); + argTypes, args, {}, funcAttr, op.getOperation()); if (std::optional optCacheControls = getCacheControlMetadata(rewriter, op)) call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); @@ -473,17 +423,18 @@ class MemfenceToOCLPattern : public OpConversionPattern { // CLUSTER and SYSTEM are not supported in OpenCL llvm_unreachable("unsupported xevm::MemoryScope"); } - Value acqRel = 
rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(4)); - Value memScopeConst = rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(memScope)); - Value addrSpaceConst = rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(addrSpace)); + Type i32Type = rewriter.getI32Type(); + Value acqRel = rewriter.create(loc, i32Type, 4); + Value memScopeConst = + rewriter.create(loc, i32Type, memScope); + Value addrSpaceConst = + rewriter.create(loc, i32Type, addrSpace); SmallVector args{addrSpaceConst, acqRel, memScopeConst}; - SmallVector argTypes{3, rewriter.getI32Type()}; + SmallVector argTypes{3, i32Type}; createDeviceFunctionCall(rewriter, mangle(fnName, argTypes), LLVM::LLVMVoidType::get(rewriter.getContext()), - argTypes, args, {}, noUnwindAttrs); + argTypes, args, {}, noUnwindAttrs, + op.getOperation()); rewriter.eraseOp(op); return success(); } @@ -512,10 +463,8 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern { auto i32Type = rewriter.getI32Type(); Value byteCoord = rewriter.create(loc, VectorType::get(2, i32Type)); - Value zero = rewriter.create( - loc, i32Type, rewriter.getI32IntegerAttr(0)); - Value one = rewriter.create( - loc, i32Type, rewriter.getI32IntegerAttr(1)); + Value zero = rewriter.create(loc, i32Type, 0); + Value one = rewriter.create(loc, i32Type, 1); byteCoord = rewriter.create( loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); byteCoord = rewriter.create( @@ -589,7 +538,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern { } LLVM::CallOp call = createDeviceFunctionCall( rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()), - argTypes, args, paramAttrs, funcAttr); + argTypes, args, paramAttrs, funcAttr, op.getOperation()); if (std::optional optCacheControls = getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) { call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir index 2d9a9d5683756..aeb9e56035653 100644 --- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir +++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir @@ -4,8 +4,11 @@ // and the generic `convert-to-llvm` pass. 
// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s -// CHECK: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} -// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -> vector<8xi16> { +// CHECK: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, +// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 @@ -14,15 +17,27 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32 // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, linkage = #llvm.linkage, no_unwind, sym_name = "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, will_return} : (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, + // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = + // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, + // CHECK-SAME: will_return} + // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, + // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16> - %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, + pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> llvm.return %loaded_a : vector<8xi16> } // ----- -// CHECK: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, 
llvm.readonly}) attributes {no_unwind, will_return} -// CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) { +// CHECK: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, +// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) { llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 @@ -32,31 +47,60 @@ llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i3 // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr // CHECK: llvm.store %[[ARG6]], %[[VAR6]] : vector<8xi32>, !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, linkage = #llvm.linkage, no_unwind, sym_name = "_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, will_return} : (!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) -> () - xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, + // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = + // CHECK-SAME: "_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, + // CHECK-SAME: will_return} + // CHECK-SAME: : (!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, + // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) -> () + xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted + <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}> + : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) llvm.return } // ----- -// CHECK: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects, no_unwind} -// CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) { +// CHECK: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes +// CHECK-SAME: {memory_effects = #llvm.memory_effects, no_unwind} +// 
CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) { llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) { // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> - // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]]) {function_type = !llvm.func, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage, memory_effects = #llvm.memory_effects, no_unwind, sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64 - xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, cache_control=#xevm.load_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, no_unwind, + // CHECK-SAME: sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64 + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y + <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, + cache_control=#xevm.load_cache_control}> + : (!llvm.ptr<1>, i32, i32, i32, i32, i32) llvm.return } // ----- -// CHECK: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects, no_unwind, will_return} +// CHECK: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( +// CHECK-SAME: vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes +// CHECK-SAME: {convergent, memory_effects = #llvm.memory_effects, no_unwind, will_return} // CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> { llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> { - // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type = !llvm.func (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage, memory_effects = #llvm.memory_effects, no_unwind, sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> - %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted { shape=, types= } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( + 
// CHECK-SAME: %[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type = + // CHECK-SAME: !llvm.func (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, no_unwind, + // CHECK-SAME: sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return} + // CHECK-SAME: : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted + { shape=, types= } + : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> llvm.return %c_result : vector<8xf32> } @@ -66,17 +110,23 @@ llvm.func @memfence() { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(2 : i32) : i32 - // CHECK: llvm.call spir_funccc @_Z22atomic_work_item_fenceiii(%[[VAR2]], %[[VAR0]], %[[VAR1]]) {function_type = !llvm.func, linkage = #llvm.linkage, no_unwind, sym_name = "_Z22atomic_work_item_fenceiii", visibility_ = 0 : i64} : (i32, i32, i32) -> () + // CHECK: llvm.call spir_funccc @_Z22atomic_work_item_fenceiii(%[[VAR2]], %[[VAR0]], %[[VAR1]]) + // CHECK-SAME: {function_type = !llvm.func, linkage = #llvm.linkage, no_unwind, + // CHECK-SAME: sym_name = "_Z22atomic_work_item_fenceiii", visibility_ = 0 : i64} : (i32, i32, i32) -> () xevm.memfence <{addrspace=#xevm.addr_space, scope=#xevm.mem_scope}> llvm.return } // ----- -// CHECK: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes {memory_effects = #llvm.memory_effects, no_unwind} +// CHECK: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes +// CHECK-SAME: {memory_effects = #llvm.memory_effects, no_unwind} // CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) { llvm.func @prefetch(%ptr: !llvm.ptr<1>) { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64 - // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]]) {function_type = !llvm.func, i64)>, linkage = #llvm.linkage, memory_effects = #llvm.memory_effects, no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64 + // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]]) + // CHECK-SAME: {function_type = !llvm.func, i64)>, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, + // CHECK-SAME: no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64 xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) llvm.return } From c79fcaef938357dd6ed8685a2578ce77f366b7be Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Wed, 9 Jul 2025 04:20:53 +0000 Subject: [PATCH 03/11] Address reviewer comment. 
--- mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 4605ef78ee50d..d93ba7915a38c 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -240,9 +240,7 @@ static LLVM::CallOp createDeviceFunctionCall( ArrayRef argTypes, ArrayRef args, mlir::ArrayRef> paramAttrs, LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) { - auto moduleOp = rewriter.getBlock() - ->getParentOp() - ->getParentWithTrait(); + auto moduleOp = op->getParentWithTrait(); assert(moduleOp && "Expecting module"); Location loc = op->getLoc(); From 9df3d2d9533fa0aad475289b094a58e066a4d8d3 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Wed, 9 Jul 2025 17:58:56 +0000 Subject: [PATCH 04/11] Remove trailing comma. --- mlir/include/mlir/Conversion/Passes.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 78e68632409d8..50c67da91a4af 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1501,7 +1501,7 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> { def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> { let summary = "Convert XeVM to LLVM dialect"; - let dependentDialects = ["LLVM::LLVMDialect", ]; + let dependentDialects = ["LLVM::LLVMDialect"]; } #endif // MLIR_CONVERSION_PASSES From 5baf9fca5022c019d515f88cf0179af21b6a4ea4 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Thu, 10 Jul 2025 14:28:53 +0000 Subject: [PATCH 05/11] Use result type instead of C type for MMA. --- mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index d93ba7915a38c..a1dce9270ec68 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -303,13 +303,14 @@ class MMAToOCLPattern : public OpConversionPattern { Value c = op.getC(); VectorType cOrigTy = cast(c.getType()); - assert(cOrigTy == op->getResultTypes()[0] && - "Accumulator and result type mismatch"); + VectorType resOrigTy = cast(op->getResultTypes()[0]); + assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch"); // OCL builtins encode bfloat16 as int16 VectorType cTy = cOrigTy.getElementType().isBF16() ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16)) : cOrigTy; + VectorType resTy = cTy; if (cOrigTy != cTy) c = rewriter.create(loc, cTy, c); @@ -332,12 +333,12 @@ class MMAToOCLPattern : public OpConversionPattern { auto funcAttrs = convergentNoUnwindWillReturnAttrs; funcAttrs.memEffectsAttr = memAttr; Value result = - createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, args, {}, + createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {}, funcAttrs, op.getOperation()) ->getResult(0); - if (cOrigTy != cTy) - result = rewriter.create(loc, cOrigTy, result); + if (resOrigTy != resTy) + result = rewriter.create(loc, resOrigTy, result); rewriter.replaceOp(op, result); return success(); From 35058586b8047a89d7760984f230fdf564437273 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Thu, 10 Jul 2025 17:35:27 +0000 Subject: [PATCH 06/11] Add v_blocks, transpose and pack_register usage case. 
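For readers of the new checks: the mangled OCL builtin name encodes the tile
parameters, and the transpose / pack_register attributes select the
"_transpose_" / "_transform_" builtin variants. A rough illustration of the
suffix, using a hypothetical helper rather than the pass's actual code; the
mapping is read off the CHECK lines below:

  // Variable names are placeholders for the op's attributes.
  // e.g. elem_size_in_bits=16, tile_height=8, tile_width=16, v_blocks=2
  // yields "16b_8r16x2c".
  std::string suffix = std::to_string(elemSizeInBits) + "b_" +
                       std::to_string(tileHeight) + "r" +
                       std::to_string(tileWidth) + "x" +
                       std::to_string(vBlocks) + "c";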
--- .../Conversion/XeVMToLLVM/xevm-to-llvm.mlir | 103 ++++++++++++++++-- 1 file changed, 95 insertions(+), 8 deletions(-) diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir index aeb9e56035653..7ad3f920d4d09 100644 --- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir +++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir @@ -4,7 +4,7 @@ // and the generic `convert-to-llvm` pass. // RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s -// CHECK: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( +// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( // CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} // CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, @@ -18,11 +18,11 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32 // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( - // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, - // CHECK-SAME: will_return} + // CHECK-SAME: will_return} : // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16> @@ -33,7 +33,94 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32 } // ----- -// CHECK: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( +// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, +// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) +llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<16xi16> { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: {function_type = 
!llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, + // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = + // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, + // CHECK-SAME: will_return} + // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, + // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<16xi16> + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=2 : i32, transpose=false, + pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return %loaded_a : vector<16xi16> +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, +// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) +llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, + // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = + // CHECK-SAME: "_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, + // CHECK-SAME: will_return} : + // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, + // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32> + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=false, + pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return %loaded_a : vector<8xi32> +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, +// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) +llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, 
%base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr + // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, + // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = + // CHECK-SAME: "_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, + // CHECK-SAME: will_return} + // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, + // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32> + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=32 : i32, tile_width=8 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=true, + pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return %loaded_a : vector<8xi32> +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( // CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return} // CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, @@ -62,7 +149,7 @@ llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i3 } // ----- -// CHECK: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( +// CHECK-LABEL: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( // CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes // CHECK-SAME: {memory_effects = #llvm.memory_effects, no_unwind} // CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, @@ -86,7 +173,7 @@ llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i } // ----- -// CHECK: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( +// CHECK-LABEL: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( // CHECK-SAME: vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes // CHECK-SAME: {convergent, memory_effects = #llvm.memory_effects, no_unwind, will_return} @@ -105,7 +192,7 @@ llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loade } // ----- -// CHECK: llvm.func spir_funccc @_Z22atomic_work_item_fenceiii(i32, i32, i32) attributes {no_unwind} +// CHECK-LABEL: llvm.func spir_funccc @_Z22atomic_work_item_fenceiii(i32, i32, i32) attributes {no_unwind} llvm.func @memfence() { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(1 : i32) : i32 @@ -118,7 +205,7 @@ llvm.func @memfence() { } // ----- -// CHECK: llvm.func 
spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes +// CHECK-LABEL: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes // CHECK-SAME: {memory_effects = #llvm.memory_effects, no_unwind} // CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) { llvm.func @prefetch(%ptr: !llvm.ptr<1>) { From 4f80a3412fcfadb6bfab4443a611929220ba497c Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Thu, 10 Jul 2025 17:43:17 +0000 Subject: [PATCH 07/11] Remove warning: default label in switch which covers all enumeration values [-Wcovered-switch-default] --- mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index a1dce9270ec68..aac160431b933 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -125,8 +125,6 @@ int32_t getL1CacheControl(OpType op) { case LoadCacheControl::INVALIDATE_READ: control = 4; break; - default: - break; } } else { switch (*op.getCacheControl()) { @@ -153,8 +151,6 @@ int32_t getL1CacheControl(OpType op) { case StoreCacheControl::L1WB_L2UC_L3WB: control = 4; break; - default: - break; } } return control; @@ -184,8 +180,6 @@ int32_t getL3CacheControl(OpType op) { case LoadCacheControl::INVALIDATE_READ: control = 4; break; - default: - break; } } else { switch (*op.getCacheControl()) { @@ -208,8 +202,6 @@ int32_t getL3CacheControl(OpType op) { case StoreCacheControl::L1WB_L2UC_L3WB: control = 2; break; - default: - break; } } return control; From bede0c3f0c16e4350843b680d5be5e31ab5e0f75 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Thu, 10 Jul 2025 20:09:32 +0000 Subject: [PATCH 08/11] Check for operand types supported by lowering. --- mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index aac160431b933..4adff4f4d72ad 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -265,6 +265,28 @@ class MMAToOCLPattern : public OpConversionPattern { if (!op.getC()) { return rewriter.notifyMatchFailure(op, "OCL requires C operand"); } + auto precisionA = op.getTypes().getA(); + auto precisionB = op.getTypes().getB(); + auto precisionC = op.getTypes().getC(); + auto precisionD = op.getTypes().getD(); + if (precisionC != precisionD) { + return rewriter.notifyMatchFailure(op, "type of C and D need to match"); + } + if (precisionC != xevm::ElemType::S32 && + precisionC != xevm::ElemType::F32 && + precisionC != xevm::ElemType::F16 && + precisionC != xevm::ElemType::BF16) { + return rewriter.notifyMatchFailure( + op, "type of C and D must be S32, F32, F16 or BF16"); + } + if (precisionA == xevm::ElemType::S32 || + precisionA == xevm::ElemType::F32) { + return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32"); + } + if (precisionB == xevm::ElemType::S32 || + precisionB == xevm::ElemType::F32) { + return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32"); + } constexpr uint32_t bitWidthPackedA{16}; constexpr uint32_t bitWidthPackedB{32}; auto loc = op.getLoc(); From ef1704bee5aee3243abe1b028268246927fc0eaa Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Thu, 10 Jul 2025 20:26:25 +0000 Subject: [PATCH 09/11] Add test cases with cache control. 
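The cache_control attribute surfaces on the generated call as the
xevm.DecorationCacheControl array checked below; 6442 and 6443 are the SPIR-V
CacheControlLoadINTEL and CacheControlStoreINTEL decoration ids, and each
entry appears to be {decoration, cache level, control value, operand index}.
A hypothetical sketch of how such an attribute could be assembled (not the
pass's verbatim code):

  // Two entries, one per cache level, matching the load test below.
  SmallVector<Attribute> entries;
  entries.push_back(rewriter.getI32ArrayAttr({6442, 0, 1, 0}));
  entries.push_back(rewriter.getI32ArrayAttr({6442, 1, 1, 0}));
  callOp->setAttr("xevm.DecorationCacheControl", rewriter.getArrayAttr(entries));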
--- .../Conversion/XeVMToLLVM/xevm-to-llvm.mlir | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir index 7ad3f920d4d09..124a11b6beacb 100644 --- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir +++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir @@ -32,6 +32,18 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32 llvm.return %loaded_a : vector<8xi16> } +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( +llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { + // CHECK: xevm.DecorationCacheControl = + // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32 + // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32 + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, + pack_register=false, cache_control=#xevm.load_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + llvm.return %loaded_a : vector<8xi16> +} + // ----- // CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt( // CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, @@ -148,6 +160,18 @@ llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i3 llvm.return } +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( +llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { + // CHECK: xevm.DecorationCacheControl = + // CHECK-SAME: 6443 : i32, 0 : i32, 2 : i32, 0 : i32 + // CHECK-SAME: 6443 : i32, 1 : i32, 2 : i32, 0 : i32 + xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted + <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control = #xevm.store_cache_control}> + : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + llvm.return +} + // ----- // CHECK-LABEL: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( // CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes From 5e49ab7a2fb7d28e7cdf0e393585ee81fe16a150 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Thu, 10 Jul 2025 21:40:47 +0000 Subject: [PATCH 10/11] Fail gracefully. Report match failure. 
--- mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 4adff4f4d72ad..bf64308d6a35c 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -423,7 +423,8 @@ class MemfenceToOCLPattern : public OpConversionPattern { break; default: // GENERIC is not supported in OpenCL - llvm_unreachable("Fence only supports global and shared address spaces."); + return rewriter.notifyMatchFailure( + op, "Fence only supports global and shared address spaces."); } switch (op.getScope()) { case xevm::MemScope::WORKGROUP: @@ -434,7 +435,8 @@ class MemfenceToOCLPattern : public OpConversionPattern { break; default: // CLUSTER and SYSTEM are not supported in OpenCL - llvm_unreachable("unsupported xevm::MemoryScope"); + return rewriter.notifyMatchFailure( + op, "Fence only supports workgroup and device memory scopes."); } Type i32Type = rewriter.getI32Type(); Value acqRel = rewriter.create(loc, i32Type, 4); From dd1c9ac5bec94828fdd2fe21f9f53c2e2d5e9874 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Thu, 10 Jul 2025 21:43:38 +0000 Subject: [PATCH 11/11] Give test cases better name. --- .../test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir index 124a11b6beacb..bdbb12bbe0cbb 100644 --- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir +++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir @@ -34,7 +34,7 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32 // ----- // CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( -llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { +llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { // CHECK: xevm.DecorationCacheControl = // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32 // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32 @@ -48,9 +48,9 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32 // CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt( // CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} -// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK: llvm.func @blockload2d_v_blocks(%[[ARG0:.*]]: !llvm.ptr<1>, // CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<16xi16> { +llvm.func @blockload2d_v_blocks(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<16xi16> { // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 @@ -77,9 +77,9 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32 // CHECK-LABEL: 
llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj( // CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} -// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK: llvm.func @blockload2d_pack_register(%[[ARG0:.*]]: !llvm.ptr<1>, // CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> { +llvm.func @blockload2d_pack_register(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> { // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 @@ -106,9 +106,9 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32 // CHECK-LABEL: llvm.func spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj( // CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} -// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK: llvm.func @blockload2d_transpose(%[[ARG0:.*]]: !llvm.ptr<1>, // CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> { +llvm.func @blockload2d_transpose(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> { // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 @@ -162,7 +162,7 @@ llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i3 // ----- // CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( -llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { +llvm.func @blockstore2d_cache_control(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { // CHECK: xevm.DecorationCacheControl = // CHECK-SAME: 6443 : i32, 0 : i32, 2 : i32, 0 : i32 // CHECK-SAME: 6443 : i32, 1 : i32, 2 : i32, 0 : i32