diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index c9d2a54433736..8a5976e547169 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -80,6 +80,7 @@ #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h" #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h" +#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" namespace mlir { diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 5a864865adffc..50c67da91a4af 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1495,4 +1495,13 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> { ]; } +//===----------------------------------------------------------------------===// +// XeVMToLLVM +//===----------------------------------------------------------------------===// + +def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> { + let summary = "Convert XeVM to LLVM dialect"; + let dependentDialects = ["LLVM::LLVMDialect"]; +} + #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h new file mode 100644 index 0000000000000..7ffdbd4307f9e --- /dev/null +++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h @@ -0,0 +1,27 @@ +//===-- XeVMToLLVM.h - Convert XeVM to LLVM dialect -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_ +#define MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_ + +#include + +namespace mlir { +class DialectRegistry; +class LLVMTypeConverter; +class RewritePatternSet; +class Pass; + +#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS +#include "mlir/Conversion/Passes.h.inc" + +void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns); + +void registerConvertXeVMToLLVMInterface(DialectRegistry ®istry); +} // namespace mlir + +#endif // MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_ diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 0f2d0e45008cc..d5a9a2c3aeba7 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -32,6 +32,7 @@ #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" @@ -91,6 +92,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { gpu::registerConvertGpuToLLVMInterface(registry); NVVM::registerConvertGpuToNVVMInterface(registry); vector::registerConvertVectorToLLVMInterface(registry); + registerConvertXeVMToLLVMInterface(registry); // Register all transform dialect extensions. affine::registerTransformDialectExtension(registry); diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index e4b4974600577..24a48993ad80c 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -73,3 +73,4 @@ add_subdirectory(VectorToLLVM) add_subdirectory(VectorToSCF) add_subdirectory(VectorToSPIRV) add_subdirectory(VectorToXeGPU) +add_subdirectory(XeVMToLLVM) diff --git a/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt new file mode 100644 index 0000000000000..4ac60d8d43472 --- /dev/null +++ b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt @@ -0,0 +1,21 @@ +add_mlir_conversion_library(MLIRXeVMToLLVM + XeVMToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeVMToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRGPUDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRXeVMDialect + MLIRPass + MLIRTransforms +) diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp new file mode 100644 index 0000000000000..bf64308d6a35c --- /dev/null +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -0,0 +1,633 @@ +//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" + +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/FormatVariadic.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace xevm; + +namespace { + +struct LLVMFuncAttributeOptions { + bool isConvergent = false; + bool isNoUnwind = false; + bool isWillReturn = false; + LLVM::MemoryEffectsAttr memEffectsAttr{}; +}; +static constexpr LLVMFuncAttributeOptions noUnwindAttrs = { + false, true, false, {}}; +static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = { + false, true, true, {}}; +static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = { + true, true, true, {}}; + +std::string getTypeMangling(Type ty, bool isUnsigned = false) { + return TypeSwitch(ty) + .Case([isUnsigned](VectorType ty) -> std::string { + return "Dv" + std::to_string(ty.getNumElements()) + "_" + + getTypeMangling(ty.getElementType(), isUnsigned); + }) + .Case([](Float16Type) -> std::string { return "Dh"; }) + .Case([](Float32Type) -> std::string { return "f"; }) + .Case([](Float64Type) -> std::string { return "d"; }) + .Case([isUnsigned](IntegerType ty) -> std::string { + switch (ty.getWidth()) { + case 8: + return isUnsigned ? "h" : "c"; + case 16: + return isUnsigned ? "t" : "s"; + case 32: + return isUnsigned ? "j" : "i"; + case 64: + return isUnsigned ? "m" : "l"; + default: + llvm_unreachable("unhandled integer type"); + } + }) + .Default([](Type) -> std::string { + llvm_unreachable("unhandled type for mangling"); + }); +} + +std::string mangle(StringRef baseName, ArrayRef types, + ArrayRef isUnsigned = {}) { + assert((isUnsigned.empty() || isUnsigned.size() == types.size()) && + "Signedness info doesn't match"); + std::string s; + llvm::raw_string_ostream os(s); + llvm::SmallDenseMap substitutions; + os << "_Z" << baseName.size() << baseName; + for (auto [idx, type] : llvm::enumerate(types)) { + auto it = substitutions.find(type); + if (it != substitutions.end()) { + os << "S"; + // First substitution is `S_`, second is `S0_`, and so on. + if (unsigned firstIdx = it->getSecond(); firstIdx > 0) + os << firstIdx - 1; + os << "_"; + } else { + if (!type.isIntOrFloat()) + substitutions[type] = substitutions.size(); + os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]); + } + } + return os.str(); +} + +template +int32_t getL1CacheControl(OpType op) { + int32_t control = 0; + if constexpr (isLoad) { + switch (*op.getCacheControl()) { + case LoadCacheControl::L1UC_L2UC_L3UC: + case LoadCacheControl::L1UC_L2UC_L3C: + case LoadCacheControl::L1UC_L2C_L3UC: + case LoadCacheControl::L1UC_L2C_L3C: + control = 1; + break; + case LoadCacheControl::L1C_L2UC_L3UC: + case LoadCacheControl::L1C_L2UC_L3C: + case LoadCacheControl::L1C_L2C_L3UC: + case LoadCacheControl::L1C_L2C_L3C: + control = 2; + break; + case LoadCacheControl::L1S_L2UC_L3UC: + case LoadCacheControl::L1S_L2UC_L3C: + case LoadCacheControl::L1S_L2C_L3UC: + case LoadCacheControl::L1S_L2C_L3C: + control = 3; + break; + case LoadCacheControl::INVALIDATE_READ: + control = 4; + break; + } + } else { + switch (*op.getCacheControl()) { + case StoreCacheControl::L1UC_L2UC_L3UC: + case StoreCacheControl::L1UC_L2UC_L3WB: + case StoreCacheControl::L1UC_L2WB_L3UC: + case StoreCacheControl::L1UC_L2WB_L3WB: + control = 1; + break; + case StoreCacheControl::L1WT_L2UC_L3UC: + case StoreCacheControl::L1WT_L2UC_L3WB: + case StoreCacheControl::L1WT_L2WB_L3UC: + case StoreCacheControl::L1WT_L2WB_L3WB: + control = 2; + break; + case StoreCacheControl::L1S_L2UC_L3UC: + case StoreCacheControl::L1S_L2UC_L3WB: + case StoreCacheControl::L1S_L2WB_L3UC: + case StoreCacheControl::L1S_L2WB_L3WB: + control = 3; + break; + case StoreCacheControl::L1WB_L2UC_L3UC: + case StoreCacheControl::L1WB_L2WB_L3UC: + case StoreCacheControl::L1WB_L2UC_L3WB: + control = 4; + break; + } + } + return control; +} + +template +int32_t getL3CacheControl(OpType op) { + int32_t control = 0; + if constexpr (isLoad) { + switch (*op.getCacheControl()) { + case LoadCacheControl::L1UC_L2UC_L3UC: + case LoadCacheControl::L1UC_L2C_L3UC: + case LoadCacheControl::L1C_L2UC_L3UC: + case LoadCacheControl::L1C_L2C_L3UC: + case LoadCacheControl::L1S_L2UC_L3UC: + case LoadCacheControl::L1S_L2C_L3UC: + control = 1; + break; + case LoadCacheControl::L1UC_L2UC_L3C: + case LoadCacheControl::L1UC_L2C_L3C: + case LoadCacheControl::L1C_L2UC_L3C: + case LoadCacheControl::L1C_L2C_L3C: + case LoadCacheControl::L1S_L2UC_L3C: + case LoadCacheControl::L1S_L2C_L3C: + control = 2; + break; + case LoadCacheControl::INVALIDATE_READ: + control = 4; + break; + } + } else { + switch (*op.getCacheControl()) { + case StoreCacheControl::L1UC_L2UC_L3UC: + case StoreCacheControl::L1UC_L2WB_L3UC: + case StoreCacheControl::L1WT_L2UC_L3UC: + case StoreCacheControl::L1WT_L2WB_L3UC: + case StoreCacheControl::L1S_L2UC_L3UC: + case StoreCacheControl::L1S_L2WB_L3UC: + case StoreCacheControl::L1WB_L2UC_L3UC: + case StoreCacheControl::L1WB_L2WB_L3UC: + control = 1; + break; + case StoreCacheControl::L1UC_L2UC_L3WB: + case StoreCacheControl::L1UC_L2WB_L3WB: + case StoreCacheControl::L1WT_L2UC_L3WB: + case StoreCacheControl::L1WT_L2WB_L3WB: + case StoreCacheControl::L1S_L2UC_L3WB: + case StoreCacheControl::L1S_L2WB_L3WB: + case StoreCacheControl::L1WB_L2UC_L3WB: + control = 2; + break; + } + } + return control; +} + +template +static std::optional +getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { + if (!op.getCacheControl()) + return {}; + constexpr int32_t decorationCacheControlArity{4}; + constexpr int32_t loadCacheControlKey{6442}; + constexpr int32_t storeCacheControlKey{6443}; + const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey}; + SmallVector decorationsL1{ + controlKey, 0, getL1CacheControl(op), 0}; + SmallVector decorationsL3{ + controlKey, 1, getL3CacheControl(op), 0}; + auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1); + auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3); + + SmallVector combinedAttrs = {arrayAttrL1, arrayAttrL3}; + return rewriter.getArrayAttr(combinedAttrs); +} + +static LLVM::CallOp createDeviceFunctionCall( + ConversionPatternRewriter &rewriter, StringRef funcName, Type retType, + ArrayRef argTypes, ArrayRef args, + mlir::ArrayRef> paramAttrs, + LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) { + auto moduleOp = op->getParentWithTrait(); + assert(moduleOp && "Expecting module"); + Location loc = op->getLoc(); + + auto funcOpRes = + LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType); + assert(!failed(funcOpRes)); + LLVM::LLVMFuncOp funcOp = funcOpRes.value(); + funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC); + funcOp.setConvergent(funcAttributeOptions.isConvergent); + funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind); + funcOp.setWillReturn(funcAttributeOptions.isWillReturn); + + if (funcAttributeOptions.memEffectsAttr) + funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr); + + for (auto [idx, attrName] : paramAttrs) + funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr()); + + auto callOp = rewriter.create(loc, funcOp, args); + callOp->setAttrs(funcOp->getAttrs()); + + return callOp; +} + +class MMAToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getC()) { + return rewriter.notifyMatchFailure(op, "OCL requires C operand"); + } + auto precisionA = op.getTypes().getA(); + auto precisionB = op.getTypes().getB(); + auto precisionC = op.getTypes().getC(); + auto precisionD = op.getTypes().getD(); + if (precisionC != precisionD) { + return rewriter.notifyMatchFailure(op, "type of C and D need to match"); + } + if (precisionC != xevm::ElemType::S32 && + precisionC != xevm::ElemType::F32 && + precisionC != xevm::ElemType::F16 && + precisionC != xevm::ElemType::BF16) { + return rewriter.notifyMatchFailure( + op, "type of C and D must be S32, F32, F16 or BF16"); + } + if (precisionA == xevm::ElemType::S32 || + precisionA == xevm::ElemType::F32) { + return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32"); + } + if (precisionB == xevm::ElemType::S32 || + precisionB == xevm::ElemType::F32) { + return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32"); + } + constexpr uint32_t bitWidthPackedA{16}; + constexpr uint32_t bitWidthPackedB{32}; + auto loc = op.getLoc(); + + auto castIfNeeded = [&](Value val, Type packedType) -> Value { + VectorType origTy = cast(val.getType()); + const uint32_t vecBitSize = + origTy.getNumElements() * + origTy.getElementType().getIntOrFloatBitWidth(); + VectorType newTy = VectorType::get( + vecBitSize / packedType.getIntOrFloatBitWidth(), packedType); + if (origTy != newTy) + val = rewriter.create(loc, newTy, val); + return val; + }; + + Value a = op.getA(); + Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32) + ? cast(rewriter.getF32Type()) + : rewriter.getIntegerType(bitWidthPackedA); + a = castIfNeeded(a, packedAType); + + Value b = op.getB(); + Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32) + ? cast(rewriter.getF32Type()) + : rewriter.getIntegerType(bitWidthPackedB); + b = castIfNeeded(b, packedBType); + + Value c = op.getC(); + VectorType cOrigTy = cast(c.getType()); + VectorType resOrigTy = cast(op->getResultTypes()[0]); + assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch"); + // OCL builtins encode bfloat16 as int16 + VectorType cTy = + cOrigTy.getElementType().isBF16() + ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16)) + : cOrigTy; + VectorType resTy = cTy; + if (cOrigTy != cTy) + c = rewriter.create(loc, cTy, c); + + constexpr int32_t systolicDepth{8}; + std::string fnName = + llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}", + stringifyElemType(op.getTypes().getA()).str(), + stringifyElemType(op.getTypes().getB()).str(), + systolicDepth * + getNumOperandsPerDword(op.getTypes().getA())) + .str(); + SmallVector argTypes{a.getType(), b.getType(), cTy}; + fnName = mangle(fnName, argTypes); + SmallVector args{a, b, c}; + + auto memAttr = rewriter.getAttr( + /*other=*/LLVM::ModRefInfo::NoModRef, + /*argMem=*/LLVM::ModRefInfo::NoModRef, + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + auto funcAttrs = convergentNoUnwindWillReturnAttrs; + funcAttrs.memEffectsAttr = memAttr; + Value result = + createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {}, + funcAttrs, op.getOperation()) + ->getResult(0); + + if (resOrigTy != resTy) + result = rewriter.create(loc, resOrigTy, result); + + rewriter.replaceOp(op, result); + return success(); + } + +private: + static unsigned getNumOperandsPerDword(xevm::ElemType pTy) { + switch (pTy) { + case xevm::ElemType::TF32: + return 1; + case xevm::ElemType::BF16: + case xevm::ElemType::F16: + return 2; + case xevm::ElemType::U8: + case xevm::ElemType::S8: + return 4; + default: + llvm_unreachable("unsupported xevm::ElemType"); + } + } +}; + +class PrefetchToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + const std::string fnName{"_Z8prefetchPU3AS1Kcm"}; + Value one = + rewriter.create(loc, rewriter.getI64Type(), 1); + SmallVector args{op.getPtr(), one}; + SmallVector argTypes; + for (auto arg : args) + argTypes.push_back(arg.getType()); + auto funcAttr = noUnwindAttrs; + auto memAttr = rewriter.getAttr( + /*other=*/LLVM::ModRefInfo::NoModRef, + /*argMem=*/LLVM::ModRefInfo::Ref, + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + funcAttr.memEffectsAttr = memAttr; + + LLVM::CallOp call = createDeviceFunctionCall( + rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()), + argTypes, args, {}, funcAttr, op.getOperation()); + if (std::optional optCacheControls = + getCacheControlMetadata(rewriter, op)) + call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); + rewriter.eraseOp(op); + return success(); + } +}; + +class MemfenceToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + const std::string fnName{"atomic_work_item_fence"}; + int memScope, addrSpace; + switch (op.getAddrspace()) { + case xevm::AddrSpace::SHARED: + addrSpace = 1; // CLK_LOCAL_MEM_FENCE + break; + case xevm::AddrSpace::GLOBAL: + addrSpace = 2; // CLK_GLOBAL_MEM_FENCE + break; + default: + // GENERIC is not supported in OpenCL + return rewriter.notifyMatchFailure( + op, "Fence only supports global and shared address spaces."); + } + switch (op.getScope()) { + case xevm::MemScope::WORKGROUP: + memScope = 1; + break; + case xevm::MemScope::DEVICE: + memScope = 2; + break; + default: + // CLUSTER and SYSTEM are not supported in OpenCL + return rewriter.notifyMatchFailure( + op, "Fence only supports workgroup and device memory scopes."); + } + Type i32Type = rewriter.getI32Type(); + Value acqRel = rewriter.create(loc, i32Type, 4); + Value memScopeConst = + rewriter.create(loc, i32Type, memScope); + Value addrSpaceConst = + rewriter.create(loc, i32Type, addrSpace); + SmallVector args{addrSpaceConst, acqRel, memScopeConst}; + SmallVector argTypes{3, i32Type}; + createDeviceFunctionCall(rewriter, mangle(fnName, argTypes), + LLVM::LLVMVoidType::get(rewriter.getContext()), + argTypes, args, {}, noUnwindAttrs, + op.getOperation()); + rewriter.eraseOp(op); + return success(); + } +}; +template +class LoadStorePrefetchToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + constexpr bool isLoad = std::is_same_v; + constexpr bool isPrefetch = std::is_same_v; + + auto loc = op.getLoc(); + VectorType vecType; + bool packReg = false; + bool transpose = false; + if constexpr (isLoad) { + vecType = op.getRes().getType(); + packReg = op.getPackRegister(); + transpose = op.getTranspose(); + } else if constexpr (!isPrefetch) { + vecType = op.getStoredVal().getType(); + } + + auto i32Type = rewriter.getI32Type(); + Value byteCoord = + rewriter.create(loc, VectorType::get(2, i32Type)); + Value zero = rewriter.create(loc, i32Type, 0); + Value one = rewriter.create(loc, i32Type, 1); + byteCoord = rewriter.create( + loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); + byteCoord = rewriter.create( + loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); + SmallVector args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(), + op.getBasePitch(), byteCoord}; + SmallVector retTypes; + Value spvLoadDstPtr; + std::string funcName{"intel_sub_group_2d_block_"}; + std::string bitWidthId; + LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs}; + SmallVector, 4> paramAttrs; + if constexpr (isPrefetch) { // Prefetch + funcName += "prefetch"; + paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())}; + auto memAttr = rewriter.getAttr( + /*other=*/LLVM::ModRefInfo::NoModRef, + /*argMem=*/LLVM::ModRefInfo::Ref, + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + funcAttr = noUnwindAttrs; + funcAttr.memEffectsAttr = memAttr; + } else { + auto vecElemType = vecType.getElementType(); + auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth(); + Value numElems = rewriter.create( + loc, i32Type, vecType.getNumElements()); + auto dstOrSrcPtr = rewriter.create( + loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType, + numElems); + args.push_back(dstOrSrcPtr); + if constexpr (isLoad) { // Load + funcName += "read"; + bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true); + if (packReg) + funcName += "_transform"; + else if (transpose) + funcName += "_transpose"; + spvLoadDstPtr = dstOrSrcPtr; + retTypes.push_back(vecType); + paramAttrs = { + std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), + std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()), + std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()), + std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()), + }; + } else { // Store + funcName += "write"; + bitWidthId = (vecElemBitWidth == 32) + ? "j" + : ((vecElemBitWidth == 16) ? "t" : "h"); + rewriter.create(loc, op.getStoredVal(), dstOrSrcPtr); + paramAttrs = { + std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), + std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()), + std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()), + std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()), + }; + } + } + + funcName = + llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(), + op.getTileHeight(), op.getTileWidth(), op.getVBlocks()) + .str(); + funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(), + funcName, isPrefetch ? "" : "P", bitWidthId) + .str(); + SmallVector argTypes; + for (auto arg : args) { + argTypes.push_back(arg.getType()); + } + LLVM::CallOp call = createDeviceFunctionCall( + rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()), + argTypes, args, paramAttrs, funcAttr, op.getOperation()); + if (std::optional optCacheControls = + getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) { + call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); + } + if constexpr (isLoad) + rewriter.replaceOp( + op, rewriter.create(loc, vecType, spvLoadDstPtr)); + else + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +struct ConvertXeVMToLLVMPass + : public impl::ConvertXeVMToLLVMPassBase { + using Base::Base; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + RewritePatternSet patterns(&getContext()); + populateXeVMToLLVMConversionPatterns(patterns); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert XeVM to LLVM. +struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateXeVMToLLVMConversionPatterns(patterns); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// + +void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) { + patterns.add, + LoadStorePrefetchToOCLPattern, + LoadStorePrefetchToOCLPattern, + MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>( + patterns.getContext()); +} + +void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) { + dialect->addInterfaces(); + }); +} diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir new file mode 100644 index 0000000000000..bdbb12bbe0cbb --- /dev/null +++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir @@ -0,0 +1,244 @@ +// RUN: mlir-opt --convert-xevm-to-llvm --split-input-file %s | FileCheck %s + +// Same below, but using the `ConvertToLLVMPatternInterface` entry point +// and the generic `convert-to-llvm` pass. +// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s + +// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, +// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) +llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, + // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = + // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, + // CHECK-SAME: will_return} : + // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, + // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16> + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, + pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + llvm.return %loaded_a : vector<8xi16> +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( +llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { + // CHECK: xevm.DecorationCacheControl = + // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32 + // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32 + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, + pack_register=false, cache_control=#xevm.load_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + llvm.return %loaded_a : vector<8xi16> +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, +// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockload2d_v_blocks(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) +llvm.func @blockload2d_v_blocks(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<16xi16> { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, + // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = + // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, + // CHECK-SAME: will_return} + // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, + // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<16xi16> + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=2 : i32, transpose=false, + pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return %loaded_a : vector<16xi16> +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, +// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockload2d_pack_register(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) +llvm.func @blockload2d_pack_register(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, + // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = + // CHECK-SAME: "_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, + // CHECK-SAME: will_return} : + // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, + // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32> + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=false, + pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return %loaded_a : vector<8xi32> +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, +// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockload2d_transpose(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) +llvm.func @blockload2d_transpose(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr + // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, + // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = + // CHECK-SAME: "_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, + // CHECK-SAME: will_return} + // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, + // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32> + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=32 : i32, tile_width=8 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=true, + pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return %loaded_a : vector<8xi32> +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, +// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) { +llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr + // CHECK: llvm.store %[[ARG6]], %[[VAR6]] : vector<8xi32>, !llvm.ptr + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, + // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = + // CHECK-SAME: "_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, + // CHECK-SAME: will_return} + // CHECK-SAME: : (!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, + // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) -> () + xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted + <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}> + : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + llvm.return +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( +llvm.func @blockstore2d_cache_control(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { + // CHECK: xevm.DecorationCacheControl = + // CHECK-SAME: 6443 : i32, 0 : i32, 2 : i32, 0 : i32 + // CHECK-SAME: 6443 : i32, 1 : i32, 2 : i32, 0 : i32 + xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted + <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control = #xevm.store_cache_control}> + : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + llvm.return +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( +// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes +// CHECK-SAME: {memory_effects = #llvm.memory_effects, no_unwind} +// CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) { +llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) { + // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> + // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> + // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]]) + // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, no_unwind, + // CHECK-SAME: sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64 + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y + <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, + cache_control=#xevm.load_cache_control}> + : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( +// CHECK-SAME: vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes +// CHECK-SAME: {convergent, memory_effects = #llvm.memory_effects, no_unwind, will_return} +// CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> { +llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> { + // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( + // CHECK-SAME: %[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type = + // CHECK-SAME: !llvm.func (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, no_unwind, + // CHECK-SAME: sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return} + // CHECK-SAME: : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted + { shape=, types= } + : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + llvm.return %c_result : vector<8xf32> +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z22atomic_work_item_fenceiii(i32, i32, i32) attributes {no_unwind} +llvm.func @memfence() { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z22atomic_work_item_fenceiii(%[[VAR2]], %[[VAR0]], %[[VAR1]]) + // CHECK-SAME: {function_type = !llvm.func, linkage = #llvm.linkage, no_unwind, + // CHECK-SAME: sym_name = "_Z22atomic_work_item_fenceiii", visibility_ = 0 : i64} : (i32, i32, i32) -> () + xevm.memfence <{addrspace=#xevm.addr_space, scope=#xevm.mem_scope}> + llvm.return +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes +// CHECK-SAME: {memory_effects = #llvm.memory_effects, no_unwind} +// CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) { +llvm.func @prefetch(%ptr: !llvm.ptr<1>) { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]]) + // CHECK-SAME: {function_type = !llvm.func, i64)>, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, + // CHECK-SAME: no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64 + xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + llvm.return +} +