diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 8a5976e547169..c9d2a54433736 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -80,7 +80,6 @@ #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h" #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h" -#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" namespace mlir { diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 50c67da91a4af..5a864865adffc 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1495,13 +1495,4 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> { ]; } -//===----------------------------------------------------------------------===// -// XeVMToLLVM -//===----------------------------------------------------------------------===// - -def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> { - let summary = "Convert XeVM to LLVM dialect"; - let dependentDialects = ["LLVM::LLVMDialect"]; -} - #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h deleted file mode 100644 index 7ffdbd4307f9e..0000000000000 --- a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h +++ /dev/null @@ -1,27 +0,0 @@ -//===-- XeVMToLLVM.h - Convert XeVM to LLVM dialect -------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#ifndef MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_ -#define MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_ - -#include - -namespace mlir { -class DialectRegistry; -class LLVMTypeConverter; -class RewritePatternSet; -class Pass; - -#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS -#include "mlir/Conversion/Passes.h.inc" - -void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns); - -void registerConvertXeVMToLLVMInterface(DialectRegistry ®istry); -} // namespace mlir - -#endif // MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_ diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index d5a9a2c3aeba7..0f2d0e45008cc 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -32,7 +32,6 @@ #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" @@ -92,7 +91,6 @@ inline void registerAllExtensions(DialectRegistry ®istry) { gpu::registerConvertGpuToLLVMInterface(registry); NVVM::registerConvertGpuToNVVMInterface(registry); vector::registerConvertVectorToLLVMInterface(registry); - registerConvertXeVMToLLVMInterface(registry); // Register all transform dialect extensions. affine::registerTransformDialectExtension(registry); diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 24a48993ad80c..e4b4974600577 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -73,4 +73,3 @@ add_subdirectory(VectorToLLVM) add_subdirectory(VectorToSCF) add_subdirectory(VectorToSPIRV) add_subdirectory(VectorToXeGPU) -add_subdirectory(XeVMToLLVM) diff --git a/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt deleted file mode 100644 index 4ac60d8d43472..0000000000000 --- a/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -add_mlir_conversion_library(MLIRXeVMToLLVM - XeVMToLLVM.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeVMToLLVM - - DEPENDS - MLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRFuncDialect - MLIRGPUDialect - MLIRLLVMCommonConversion - MLIRLLVMDialect - MLIRXeVMDialect - MLIRPass - MLIRTransforms -) diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp deleted file mode 100644 index bf64308d6a35c..0000000000000 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ /dev/null @@ -1,633 +0,0 @@ -//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" - -#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/XeVMDialect.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "llvm/Support/FormatVariadic.h" - -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Types.h" - -#include "llvm/ADT/TypeSwitch.h" - -namespace mlir { -#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS -#include "mlir/Conversion/Passes.h.inc" -} // namespace mlir - -using namespace mlir; -using namespace xevm; - -namespace { - -struct LLVMFuncAttributeOptions { - bool isConvergent = false; - bool isNoUnwind = false; - bool isWillReturn = false; - LLVM::MemoryEffectsAttr memEffectsAttr{}; -}; -static constexpr LLVMFuncAttributeOptions noUnwindAttrs = { - false, true, false, {}}; -static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = { - false, true, true, {}}; -static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = { - true, true, true, {}}; - -std::string getTypeMangling(Type ty, bool isUnsigned = false) { - return TypeSwitch(ty) - .Case([isUnsigned](VectorType ty) -> std::string { - return "Dv" + std::to_string(ty.getNumElements()) + "_" + - getTypeMangling(ty.getElementType(), isUnsigned); - }) - .Case([](Float16Type) -> std::string { return "Dh"; }) - .Case([](Float32Type) -> std::string { return "f"; }) - .Case([](Float64Type) -> std::string { return "d"; }) - .Case([isUnsigned](IntegerType ty) -> std::string { - switch (ty.getWidth()) { - case 8: - return isUnsigned ? "h" : "c"; - case 16: - return isUnsigned ? "t" : "s"; - case 32: - return isUnsigned ? "j" : "i"; - case 64: - return isUnsigned ? "m" : "l"; - default: - llvm_unreachable("unhandled integer type"); - } - }) - .Default([](Type) -> std::string { - llvm_unreachable("unhandled type for mangling"); - }); -} - -std::string mangle(StringRef baseName, ArrayRef types, - ArrayRef isUnsigned = {}) { - assert((isUnsigned.empty() || isUnsigned.size() == types.size()) && - "Signedness info doesn't match"); - std::string s; - llvm::raw_string_ostream os(s); - llvm::SmallDenseMap substitutions; - os << "_Z" << baseName.size() << baseName; - for (auto [idx, type] : llvm::enumerate(types)) { - auto it = substitutions.find(type); - if (it != substitutions.end()) { - os << "S"; - // First substitution is `S_`, second is `S0_`, and so on. - if (unsigned firstIdx = it->getSecond(); firstIdx > 0) - os << firstIdx - 1; - os << "_"; - } else { - if (!type.isIntOrFloat()) - substitutions[type] = substitutions.size(); - os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]); - } - } - return os.str(); -} - -template -int32_t getL1CacheControl(OpType op) { - int32_t control = 0; - if constexpr (isLoad) { - switch (*op.getCacheControl()) { - case LoadCacheControl::L1UC_L2UC_L3UC: - case LoadCacheControl::L1UC_L2UC_L3C: - case LoadCacheControl::L1UC_L2C_L3UC: - case LoadCacheControl::L1UC_L2C_L3C: - control = 1; - break; - case LoadCacheControl::L1C_L2UC_L3UC: - case LoadCacheControl::L1C_L2UC_L3C: - case LoadCacheControl::L1C_L2C_L3UC: - case LoadCacheControl::L1C_L2C_L3C: - control = 2; - break; - case LoadCacheControl::L1S_L2UC_L3UC: - case LoadCacheControl::L1S_L2UC_L3C: - case LoadCacheControl::L1S_L2C_L3UC: - case LoadCacheControl::L1S_L2C_L3C: - control = 3; - break; - case LoadCacheControl::INVALIDATE_READ: - control = 4; - break; - } - } else { - switch (*op.getCacheControl()) { - case StoreCacheControl::L1UC_L2UC_L3UC: - case StoreCacheControl::L1UC_L2UC_L3WB: - case StoreCacheControl::L1UC_L2WB_L3UC: - case StoreCacheControl::L1UC_L2WB_L3WB: - control = 1; - break; - case StoreCacheControl::L1WT_L2UC_L3UC: - case StoreCacheControl::L1WT_L2UC_L3WB: - case StoreCacheControl::L1WT_L2WB_L3UC: - case StoreCacheControl::L1WT_L2WB_L3WB: - control = 2; - break; - case StoreCacheControl::L1S_L2UC_L3UC: - case StoreCacheControl::L1S_L2UC_L3WB: - case StoreCacheControl::L1S_L2WB_L3UC: - case StoreCacheControl::L1S_L2WB_L3WB: - control = 3; - break; - case StoreCacheControl::L1WB_L2UC_L3UC: - case StoreCacheControl::L1WB_L2WB_L3UC: - case StoreCacheControl::L1WB_L2UC_L3WB: - control = 4; - break; - } - } - return control; -} - -template -int32_t getL3CacheControl(OpType op) { - int32_t control = 0; - if constexpr (isLoad) { - switch (*op.getCacheControl()) { - case LoadCacheControl::L1UC_L2UC_L3UC: - case LoadCacheControl::L1UC_L2C_L3UC: - case LoadCacheControl::L1C_L2UC_L3UC: - case LoadCacheControl::L1C_L2C_L3UC: - case LoadCacheControl::L1S_L2UC_L3UC: - case LoadCacheControl::L1S_L2C_L3UC: - control = 1; - break; - case LoadCacheControl::L1UC_L2UC_L3C: - case LoadCacheControl::L1UC_L2C_L3C: - case LoadCacheControl::L1C_L2UC_L3C: - case LoadCacheControl::L1C_L2C_L3C: - case LoadCacheControl::L1S_L2UC_L3C: - case LoadCacheControl::L1S_L2C_L3C: - control = 2; - break; - case LoadCacheControl::INVALIDATE_READ: - control = 4; - break; - } - } else { - switch (*op.getCacheControl()) { - case StoreCacheControl::L1UC_L2UC_L3UC: - case StoreCacheControl::L1UC_L2WB_L3UC: - case StoreCacheControl::L1WT_L2UC_L3UC: - case StoreCacheControl::L1WT_L2WB_L3UC: - case StoreCacheControl::L1S_L2UC_L3UC: - case StoreCacheControl::L1S_L2WB_L3UC: - case StoreCacheControl::L1WB_L2UC_L3UC: - case StoreCacheControl::L1WB_L2WB_L3UC: - control = 1; - break; - case StoreCacheControl::L1UC_L2UC_L3WB: - case StoreCacheControl::L1UC_L2WB_L3WB: - case StoreCacheControl::L1WT_L2UC_L3WB: - case StoreCacheControl::L1WT_L2WB_L3WB: - case StoreCacheControl::L1S_L2UC_L3WB: - case StoreCacheControl::L1S_L2WB_L3WB: - case StoreCacheControl::L1WB_L2UC_L3WB: - control = 2; - break; - } - } - return control; -} - -template -static std::optional -getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { - if (!op.getCacheControl()) - return {}; - constexpr int32_t decorationCacheControlArity{4}; - constexpr int32_t loadCacheControlKey{6442}; - constexpr int32_t storeCacheControlKey{6443}; - const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey}; - SmallVector decorationsL1{ - controlKey, 0, getL1CacheControl(op), 0}; - SmallVector decorationsL3{ - controlKey, 1, getL3CacheControl(op), 0}; - auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1); - auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3); - - SmallVector combinedAttrs = {arrayAttrL1, arrayAttrL3}; - return rewriter.getArrayAttr(combinedAttrs); -} - -static LLVM::CallOp createDeviceFunctionCall( - ConversionPatternRewriter &rewriter, StringRef funcName, Type retType, - ArrayRef argTypes, ArrayRef args, - mlir::ArrayRef> paramAttrs, - LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) { - auto moduleOp = op->getParentWithTrait(); - assert(moduleOp && "Expecting module"); - Location loc = op->getLoc(); - - auto funcOpRes = - LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType); - assert(!failed(funcOpRes)); - LLVM::LLVMFuncOp funcOp = funcOpRes.value(); - funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC); - funcOp.setConvergent(funcAttributeOptions.isConvergent); - funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind); - funcOp.setWillReturn(funcAttributeOptions.isWillReturn); - - if (funcAttributeOptions.memEffectsAttr) - funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr); - - for (auto [idx, attrName] : paramAttrs) - funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr()); - - auto callOp = rewriter.create(loc, funcOp, args); - callOp->setAttrs(funcOp->getAttrs()); - - return callOp; -} - -class MMAToOCLPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getC()) { - return rewriter.notifyMatchFailure(op, "OCL requires C operand"); - } - auto precisionA = op.getTypes().getA(); - auto precisionB = op.getTypes().getB(); - auto precisionC = op.getTypes().getC(); - auto precisionD = op.getTypes().getD(); - if (precisionC != precisionD) { - return rewriter.notifyMatchFailure(op, "type of C and D need to match"); - } - if (precisionC != xevm::ElemType::S32 && - precisionC != xevm::ElemType::F32 && - precisionC != xevm::ElemType::F16 && - precisionC != xevm::ElemType::BF16) { - return rewriter.notifyMatchFailure( - op, "type of C and D must be S32, F32, F16 or BF16"); - } - if (precisionA == xevm::ElemType::S32 || - precisionA == xevm::ElemType::F32) { - return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32"); - } - if (precisionB == xevm::ElemType::S32 || - precisionB == xevm::ElemType::F32) { - return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32"); - } - constexpr uint32_t bitWidthPackedA{16}; - constexpr uint32_t bitWidthPackedB{32}; - auto loc = op.getLoc(); - - auto castIfNeeded = [&](Value val, Type packedType) -> Value { - VectorType origTy = cast(val.getType()); - const uint32_t vecBitSize = - origTy.getNumElements() * - origTy.getElementType().getIntOrFloatBitWidth(); - VectorType newTy = VectorType::get( - vecBitSize / packedType.getIntOrFloatBitWidth(), packedType); - if (origTy != newTy) - val = rewriter.create(loc, newTy, val); - return val; - }; - - Value a = op.getA(); - Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32) - ? cast(rewriter.getF32Type()) - : rewriter.getIntegerType(bitWidthPackedA); - a = castIfNeeded(a, packedAType); - - Value b = op.getB(); - Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32) - ? cast(rewriter.getF32Type()) - : rewriter.getIntegerType(bitWidthPackedB); - b = castIfNeeded(b, packedBType); - - Value c = op.getC(); - VectorType cOrigTy = cast(c.getType()); - VectorType resOrigTy = cast(op->getResultTypes()[0]); - assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch"); - // OCL builtins encode bfloat16 as int16 - VectorType cTy = - cOrigTy.getElementType().isBF16() - ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16)) - : cOrigTy; - VectorType resTy = cTy; - if (cOrigTy != cTy) - c = rewriter.create(loc, cTy, c); - - constexpr int32_t systolicDepth{8}; - std::string fnName = - llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}", - stringifyElemType(op.getTypes().getA()).str(), - stringifyElemType(op.getTypes().getB()).str(), - systolicDepth * - getNumOperandsPerDword(op.getTypes().getA())) - .str(); - SmallVector argTypes{a.getType(), b.getType(), cTy}; - fnName = mangle(fnName, argTypes); - SmallVector args{a, b, c}; - - auto memAttr = rewriter.getAttr( - /*other=*/LLVM::ModRefInfo::NoModRef, - /*argMem=*/LLVM::ModRefInfo::NoModRef, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); - auto funcAttrs = convergentNoUnwindWillReturnAttrs; - funcAttrs.memEffectsAttr = memAttr; - Value result = - createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {}, - funcAttrs, op.getOperation()) - ->getResult(0); - - if (resOrigTy != resTy) - result = rewriter.create(loc, resOrigTy, result); - - rewriter.replaceOp(op, result); - return success(); - } - -private: - static unsigned getNumOperandsPerDword(xevm::ElemType pTy) { - switch (pTy) { - case xevm::ElemType::TF32: - return 1; - case xevm::ElemType::BF16: - case xevm::ElemType::F16: - return 2; - case xevm::ElemType::U8: - case xevm::ElemType::S8: - return 4; - default: - llvm_unreachable("unsupported xevm::ElemType"); - } - } -}; - -class PrefetchToOCLPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - const std::string fnName{"_Z8prefetchPU3AS1Kcm"}; - Value one = - rewriter.create(loc, rewriter.getI64Type(), 1); - SmallVector args{op.getPtr(), one}; - SmallVector argTypes; - for (auto arg : args) - argTypes.push_back(arg.getType()); - auto funcAttr = noUnwindAttrs; - auto memAttr = rewriter.getAttr( - /*other=*/LLVM::ModRefInfo::NoModRef, - /*argMem=*/LLVM::ModRefInfo::Ref, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); - funcAttr.memEffectsAttr = memAttr; - - LLVM::CallOp call = createDeviceFunctionCall( - rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()), - argTypes, args, {}, funcAttr, op.getOperation()); - if (std::optional optCacheControls = - getCacheControlMetadata(rewriter, op)) - call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); - rewriter.eraseOp(op); - return success(); - } -}; - -class MemfenceToOCLPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - const std::string fnName{"atomic_work_item_fence"}; - int memScope, addrSpace; - switch (op.getAddrspace()) { - case xevm::AddrSpace::SHARED: - addrSpace = 1; // CLK_LOCAL_MEM_FENCE - break; - case xevm::AddrSpace::GLOBAL: - addrSpace = 2; // CLK_GLOBAL_MEM_FENCE - break; - default: - // GENERIC is not supported in OpenCL - return rewriter.notifyMatchFailure( - op, "Fence only supports global and shared address spaces."); - } - switch (op.getScope()) { - case xevm::MemScope::WORKGROUP: - memScope = 1; - break; - case xevm::MemScope::DEVICE: - memScope = 2; - break; - default: - // CLUSTER and SYSTEM are not supported in OpenCL - return rewriter.notifyMatchFailure( - op, "Fence only supports workgroup and device memory scopes."); - } - Type i32Type = rewriter.getI32Type(); - Value acqRel = rewriter.create(loc, i32Type, 4); - Value memScopeConst = - rewriter.create(loc, i32Type, memScope); - Value addrSpaceConst = - rewriter.create(loc, i32Type, addrSpace); - SmallVector args{addrSpaceConst, acqRel, memScopeConst}; - SmallVector argTypes{3, i32Type}; - createDeviceFunctionCall(rewriter, mangle(fnName, argTypes), - LLVM::LLVMVoidType::get(rewriter.getContext()), - argTypes, args, {}, noUnwindAttrs, - op.getOperation()); - rewriter.eraseOp(op); - return success(); - } -}; -template -class LoadStorePrefetchToOCLPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - constexpr bool isLoad = std::is_same_v; - constexpr bool isPrefetch = std::is_same_v; - - auto loc = op.getLoc(); - VectorType vecType; - bool packReg = false; - bool transpose = false; - if constexpr (isLoad) { - vecType = op.getRes().getType(); - packReg = op.getPackRegister(); - transpose = op.getTranspose(); - } else if constexpr (!isPrefetch) { - vecType = op.getStoredVal().getType(); - } - - auto i32Type = rewriter.getI32Type(); - Value byteCoord = - rewriter.create(loc, VectorType::get(2, i32Type)); - Value zero = rewriter.create(loc, i32Type, 0); - Value one = rewriter.create(loc, i32Type, 1); - byteCoord = rewriter.create( - loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); - byteCoord = rewriter.create( - loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); - SmallVector args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(), - op.getBasePitch(), byteCoord}; - SmallVector retTypes; - Value spvLoadDstPtr; - std::string funcName{"intel_sub_group_2d_block_"}; - std::string bitWidthId; - LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs}; - SmallVector, 4> paramAttrs; - if constexpr (isPrefetch) { // Prefetch - funcName += "prefetch"; - paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())}; - auto memAttr = rewriter.getAttr( - /*other=*/LLVM::ModRefInfo::NoModRef, - /*argMem=*/LLVM::ModRefInfo::Ref, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); - funcAttr = noUnwindAttrs; - funcAttr.memEffectsAttr = memAttr; - } else { - auto vecElemType = vecType.getElementType(); - auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth(); - Value numElems = rewriter.create( - loc, i32Type, vecType.getNumElements()); - auto dstOrSrcPtr = rewriter.create( - loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType, - numElems); - args.push_back(dstOrSrcPtr); - if constexpr (isLoad) { // Load - funcName += "read"; - bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true); - if (packReg) - funcName += "_transform"; - else if (transpose) - funcName += "_transpose"; - spvLoadDstPtr = dstOrSrcPtr; - retTypes.push_back(vecType); - paramAttrs = { - std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), - std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()), - std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()), - std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()), - }; - } else { // Store - funcName += "write"; - bitWidthId = (vecElemBitWidth == 32) - ? "j" - : ((vecElemBitWidth == 16) ? "t" : "h"); - rewriter.create(loc, op.getStoredVal(), dstOrSrcPtr); - paramAttrs = { - std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), - std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()), - std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()), - std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()), - }; - } - } - - funcName = - llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(), - op.getTileHeight(), op.getTileWidth(), op.getVBlocks()) - .str(); - funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(), - funcName, isPrefetch ? "" : "P", bitWidthId) - .str(); - SmallVector argTypes; - for (auto arg : args) { - argTypes.push_back(arg.getType()); - } - LLVM::CallOp call = createDeviceFunctionCall( - rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()), - argTypes, args, paramAttrs, funcAttr, op.getOperation()); - if (std::optional optCacheControls = - getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) { - call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); - } - if constexpr (isLoad) - rewriter.replaceOp( - op, rewriter.create(loc, vecType, spvLoadDstPtr)); - else - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Pass Definition -//===----------------------------------------------------------------------===// - -struct ConvertXeVMToLLVMPass - : public impl::ConvertXeVMToLLVMPassBase { - using Base::Base; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addIllegalDialect(); - RewritePatternSet patterns(&getContext()); - populateXeVMToLLVMConversionPatterns(patterns); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - signalPassFailure(); - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// ConvertToLLVMPatternInterface implementation -//===----------------------------------------------------------------------===// - -namespace { -/// Implement the interface to convert XeVM to LLVM. -struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { - using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; - void loadDependentDialects(MLIRContext *context) const final { - context->loadDialect(); - } - - /// Hook for derived dialect interface to provide conversion patterns - /// and mark dialect legal for the conversion target. - void populateConvertToLLVMConversionPatterns( - ConversionTarget &target, LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns) const final { - populateXeVMToLLVMConversionPatterns(patterns); - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// Pattern Population -//===----------------------------------------------------------------------===// - -void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) { - patterns.add, - LoadStorePrefetchToOCLPattern, - LoadStorePrefetchToOCLPattern, - MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>( - patterns.getContext()); -} - -void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry ®istry) { - registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) { - dialect->addInterfaces(); - }); -} diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir deleted file mode 100644 index bdbb12bbe0cbb..0000000000000 --- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir +++ /dev/null @@ -1,244 +0,0 @@ -// RUN: mlir-opt --convert-xevm-to-llvm --split-input-file %s | FileCheck %s - -// Same below, but using the `ConvertToLLVMPatternInterface` entry point -// and the generic `convert-to-llvm` pass. -// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s - -// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( -// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, -// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} -// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, -// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { - // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> - // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> - // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> - // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( - // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) - // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, - // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = - // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, - // CHECK-SAME: will_return} : - // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, - // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () - // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16> - %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y - <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, - pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> - llvm.return %loaded_a : vector<8xi16> -} - -// ----- -// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( -llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { - // CHECK: xevm.DecorationCacheControl = - // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32 - // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32 - %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y - <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, - pack_register=false, cache_control=#xevm.load_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> - llvm.return %loaded_a : vector<8xi16> -} - -// ----- -// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt( -// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, -// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} -// CHECK: llvm.func @blockload2d_v_blocks(%[[ARG0:.*]]: !llvm.ptr<1>, -// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -llvm.func @blockload2d_v_blocks(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<16xi16> { - // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> - // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> - // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> - // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt( - // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) - // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, - // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = - // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, - // CHECK-SAME: will_return} - // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, - // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () - // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<16xi16> - %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y - <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=2 : i32, transpose=false, - pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> - llvm.return %loaded_a : vector<16xi16> -} - -// ----- -// CHECK-LABEL: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj( -// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, -// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} -// CHECK: llvm.func @blockload2d_pack_register(%[[ARG0:.*]]: !llvm.ptr<1>, -// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -llvm.func @blockload2d_pack_register(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> { - // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> - // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> - // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> - // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj( - // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) - // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, - // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = - // CHECK-SAME: "_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, - // CHECK-SAME: will_return} : - // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, - // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () - // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32> - %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y - <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=false, - pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> - llvm.return %loaded_a : vector<8xi32> -} - -// ----- -// CHECK-LABEL: llvm.func spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj( -// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, -// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} -// CHECK: llvm.func @blockload2d_transpose(%[[ARG0:.*]]: !llvm.ptr<1>, -// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -llvm.func @blockload2d_transpose(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> { - // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> - // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> - // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> - // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj( - // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) - // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, - // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = - // CHECK-SAME: "_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, - // CHECK-SAME: will_return} - // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, - // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> () - // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32> - %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y - <{elem_size_in_bits=32 : i32, tile_width=8 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=true, - pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> - llvm.return %loaded_a : vector<8xi32> -} - -// ----- -// CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( -// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, -// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return} -// CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, -// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) { -llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { - // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> - // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> - // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> - // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr - // CHECK: llvm.store %[[ARG6]], %[[VAR6]] : vector<8xi32>, !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( - // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) - // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>, ptr)>, - // CHECK-SAME: linkage = #llvm.linkage, no_unwind, sym_name = - // CHECK-SAME: "_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, - // CHECK-SAME: will_return} - // CHECK-SAME: : (!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, - // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) -> () - xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted - <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}> - : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) - llvm.return -} - -// ----- -// CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj( -llvm.func @blockstore2d_cache_control(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { - // CHECK: xevm.DecorationCacheControl = - // CHECK-SAME: 6443 : i32, 0 : i32, 2 : i32, 0 : i32 - // CHECK-SAME: 6443 : i32, 1 : i32, 2 : i32, 0 : i32 - xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted - <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control = #xevm.store_cache_control}> - : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) - llvm.return -} - -// ----- -// CHECK-LABEL: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( -// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes -// CHECK-SAME: {memory_effects = #llvm.memory_effects, no_unwind} -// CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, -// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) { -llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) { - // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32> - // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32> - // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32> - // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( - // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]]) - // CHECK-SAME: {function_type = !llvm.func, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage, - // CHECK-SAME: memory_effects = #llvm.memory_effects, no_unwind, - // CHECK-SAME: sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64 - xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y - <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, - cache_control=#xevm.load_cache_control}> - : (!llvm.ptr<1>, i32, i32, i32, i32, i32) - llvm.return -} - -// ----- -// CHECK-LABEL: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( -// CHECK-SAME: vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes -// CHECK-SAME: {convergent, memory_effects = #llvm.memory_effects, no_unwind, will_return} -// CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> { -llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> { - // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( - // CHECK-SAME: %[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type = - // CHECK-SAME: !llvm.func (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage, - // CHECK-SAME: memory_effects = #llvm.memory_effects, no_unwind, - // CHECK-SAME: sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return} - // CHECK-SAME: : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> - %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted - { shape=, types= } - : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> - llvm.return %c_result : vector<8xf32> -} - -// ----- -// CHECK-LABEL: llvm.func spir_funccc @_Z22atomic_work_item_fenceiii(i32, i32, i32) attributes {no_unwind} -llvm.func @memfence() { - // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(4 : i32) : i32 - // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(2 : i32) : i32 - // CHECK: llvm.call spir_funccc @_Z22atomic_work_item_fenceiii(%[[VAR2]], %[[VAR0]], %[[VAR1]]) - // CHECK-SAME: {function_type = !llvm.func, linkage = #llvm.linkage, no_unwind, - // CHECK-SAME: sym_name = "_Z22atomic_work_item_fenceiii", visibility_ = 0 : i64} : (i32, i32, i32) -> () - xevm.memfence <{addrspace=#xevm.addr_space, scope=#xevm.mem_scope}> - llvm.return -} - -// ----- -// CHECK-LABEL: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes -// CHECK-SAME: {memory_effects = #llvm.memory_effects, no_unwind} -// CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) { -llvm.func @prefetch(%ptr: !llvm.ptr<1>) { - // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64 - // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]]) - // CHECK-SAME: {function_type = !llvm.func, i64)>, linkage = #llvm.linkage, - // CHECK-SAME: memory_effects = #llvm.memory_effects, - // CHECK-SAME: no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64 - xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) - llvm.return -} -