From 988bf661a4a59441f2a657b7b3e540b414f1cda7 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Fri, 28 Mar 2025 13:04:30 +0000 Subject: [PATCH 01/11] [uArch][XeGPU] Add uArch definition. --- .../mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h | 100 ++++++ mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h | 285 ++++++++++++++++++ mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt | 6 +- mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp | 48 +++ .../Dialect/XeGPU/Utils/intel_gpu_pvc.yaml | 205 +++++++++++++ 5 files changed, 643 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h create mode 100644 mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h create mode 100644 mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp create mode 100644 mlir/lib/Dialect/XeGPU/Utils/intel_gpu_pvc.yaml diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h b/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h new file mode 100644 index 0000000000000..0b1328b55ed86 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h @@ -0,0 +1,100 @@ +//===--- uArch.h ---------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// PVC uArch definition. 
+/// +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_PVC_H +#define MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_PVC_H + +#include "mlir/Dialect/XeGPU/Utils/uArch.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include +#include +#include + +namespace mlir { +namespace xegpu { +namespace uArch { +namespace PVCuArch { +struct XeCoreInfo { + uint num_threads; + SharedMemory shared_memory; + uint num_vector_units; + uint num_matrix_units; +}; + +struct Xe2Plus : public uArch { + XeCoreInfo xe_core; +}; + +// struct to represent DPAS instruction +struct DPASInstruction : public Instruction { + Range systolic_depth; + Range repreat_count; + Range execution_size; + std::map ops_per_channel; + std::vector> supported_types; + std::map>> + matrix_size; + + bool checkSupportedDPASTypes(mlir::Type dstType, mlir::Type src0Type, + mlir::Type src1Type, mlir::Type src2Type); +}; + +struct LoadStore2DTileInfo : public RangeTile { + std::vector array_len; +}; + +// struct to represent Load2D/Store2D/Prefetch instruction +struct LoadStorePrefetch2DInstruction : public Instruction { + MemoryType memory_type; + MemoryAccessType memory_access_type; + // std::vector supported_types; + std::vector supported_types_bitwidth; + std::map alignment; + LoadStore2DTileInfo supported_tile_sizes; + uint min_surface_pitch; + + // Validate Array length restriction on a given tile + bool validateArrayLenRestriction(Tile tile, uint array_len, + mlir::Type dataType) { + + Restriction width_array_len_restriction( + tile, array_len, dataType, + [](Tile tile, uint array_len, mlir::Type dataType) { + assert(tile.no_of_dims == 2); + return tile.dims[1] * array_len * + (dataType.getIntOrFloatBitWidth() / 8) <= + 64; + }); + return width_array_len_restriction.validate(); + } + + // Validate Surface Pitch restriction on a given tile + bool validateSurfacePitchRestriction(Tile tile, + uint surfacePitch /*in 
bytes*/) { + Restriction surface_pitch_restriction( + tile, surfacePitch, [](Tile tile, uint surfacePitch) { + assert(tile.no_of_dims == 2); + return surfacePitch >= 64; + }); + return surface_pitch_restriction.validate(); + } +}; + +} // namespace PVCuArch +} // namespace uArch +} // namespace xegpu +} // namespace mlir + +#endif // MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_PVC_H +//===--- IntelGpuPVC.h ---------------------------------------*- C++ -*-===// diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h b/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h new file mode 100644 index 0000000000000..4c23eaec873b1 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h @@ -0,0 +1,285 @@ +//===--- uArch.h ---------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// Base uArch definition for different architectures. 
+/// +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_XEGPU_UTILS_UARCH_H +#define MLIR_DIALECT_XEGPU_UTILS_UARCH_H + +#include +#include +#include +namespace mlir { +namespace xegpu { +namespace uArch { + +// Data types we need for YAML to uArch translation +struct Range { + int start; + int end; +}; + +// Tile can be multi-dimensional +// For example, a 2D tile can be represented as: +// Tile: +// no_of_dims: 2 +// dims: [2, 2] +// This represents a 2x2 tile +struct Tile { + uint no_of_dims; + std::vector dims; +}; + +// RangeTile represents a range of tiles instead of a single tile +// RangeTile essentially provides a way to represent the supported range of values +// in each dimension. For each dimension, the range of values is represented as a +// Range. For example, a 2D RangeTile can be represented as: RangeTile: +// no_of_dims: 2 +// dims: +// - [1, 32] +// - [2, 16] +// This represents a 2D RangeTile where the first dimension can have values +// from 1 to 32 and the second dimension can have values from 2 to 16 +struct RangeTile { + uint no_of_dims; + std::vector dims; +}; + +// DiscreteTile represents a set of tiles instead of a single tile +// DiscreteTile essentially provides a way to represent the supported set of values +// in each dimension. For each dimension, the set of values is represented as a +// vector of integers. For example, a 2D DiscreteTile can be represented as: +// DiscreteTile: +// no_of_dims: 2 +// dims: +// - [1, 2, 4, 8, 16, 32] +// - [2, 4, 8, 16] +// This represents a 2D DiscreteTile where the first dimension can have values +// 1, 2, 4, 8, 16, 32 and the second dimension can have values 2, 4, 8, 16 +struct DiscreteTile { + uint no_of_dims; + std::vector> dims; +}; + +// Restriction struct +// This struct is used to represent a restriction on the uArch +// The restriction is represented as a range of necessary parameters (template +// arguments) and a lambda function (validate()) 
that takes the same number of +// arguments as the number of template arguments The lambda function returns +// true if the arguments satisfy the restriction The lambda function returns +// false if the arguments do not satisfy the restriction + +// For example, a restriction that checks if the number of dimensions in a +// RangeTile is 2 can be represented as: RangeTile rt = {2, {{1, 32}, {2, 16}}}; +// Restriction r1(rt, [](RangeTile t) { return t.no_of_dims == 2; }); +// r1.validate() will return true if the number of dimensions in the RangeTile +// is 2 r1.validate() will return false if the number of dimensions in the +// RangeTile is not 2 + +// The primary purpose of Restriction struct is to provide a generic way to +// represent restrictions on the uArch and to validate if the uArch satisfies +// the restrictions +template +struct Restriction { + std::tuple data; + std::function func; + + Restriction(Args... args, std::function f) + : data(args...), func(f) {} + + bool validate() { return std::apply(func, data); } + std::any apply() { return std::apply(func, data); } +}; + +// An enum class to represent the functional unit of an instruction +enum class FunctionalUnit { + ALU, + Tensor, + Matrix, + Load, + Store, + Branch, + Barrier, + Memory, + Atomic, + Interconnect, + Other +}; + +// An enum class to represent the type of memory +enum class MemoryType { Shared, Local, Global, Constant, Texture, Other }; + +// An enum class to represent the memory access type +enum class MemoryAccessType { Read, Write, ReadWrite, Other }; + +// An enum class to represent the type of an instruction +enum class InstructionType { SIMT, SIMD, SPMD, MIMD, Other }; + +// An enum class to represent the scope of an instruction +enum class InstructionScope { + WorkItem, + Subgroup, + Workgroup, + Cluster, + Thread, // For CPU + Core, // For CPU + Other +}; + +// An enum class to represent the unit of computation of an instruction +enum class UnitOfComputation { + Scalar, + Vector, 
// 1-D vector + Matrix, + Tile, + Other +}; + +// A struct to represent basic information about an instruction +// This struct is used to represent the information about an instruction in the +// uArch. The information includes: +// - the name of the instruction, +// - the opcode, +// - the functional unit, +// - the type of the instruction, +// - the scope of the instruction, +// - the unit of computation, +// - the description of the instruction +// The information is represented as strings +// For example, the information about an instruction can be represented as: +// Instruction info = {"dpas", "0x83", "matrix", "simd", "subgroup", "tile", +// "Dot Product Accumulate Systolic (DPAS) is a matrix multiply-add +// operation"}; + +// The primary purpose of Instruction struct is to provide a generic way to +// represent information about an instruction and to use this information to +// generate the uArch. Specific instructions in a uArch can inherit from this +// struct and add more fields as needed + +struct Instruction { + std::string name; + std::string description; + std::string opcode; + FunctionalUnit functional_unit; + InstructionType type; + InstructionScope scope; + UnitOfComputation unit_of_computation; + + // @TODO: Add more fields as needed + // std::string latency; + // std::string throughput; + // std::string pipeline; + // std::string resource; + // std::string comment; +}; + +// A struct to represent register file information +struct RegisterFileInfo { + uint size; // size per register in bits + std::vector mode; // e.g., "small", "large" GRF modes + std::vector + num_regs_per_thread_per_mode; // number of registers per thread per mode + uint num_banks; + uint bank_size; +}; + +// A struct to represent cache information +struct CacheInfo { + uint size; + uint associativity; + uint line_size; + uint num_banks; + uint bank_size; + uint num_ports; + uint port_width; + uint bank_conflicts; +}; + +// A struct to represent the uArch +// This struct is used 
to represent the microarchitecture of a target device +// The uArch includes: +// - the name of the uArch, +// - the description of the uArch, +// - the range of tiles supported by the uArch, +// - the set of tiles supported by the uArch, +// - the set of instructions supported by the uArch, +// - the set of restrictions on the uArch +// The information is represented as strings, RangeTile, DiscreteTile, +// Instruction and Restriction structs. For example, the information about a +// uArch can be represented as: uArch uarch = {"XeHPG", "Intel Xe HPG +// microarchitecture", {2, {{1, 32}, {1, 32}}}, {2, {{1, 2, 4, 8, 16, 32}, {1, +// 2, 4, 8, 16, 32}}}, {{"dpas", "0x83", "matrix", "simd", "subgroup", "tile", +// "Dot Product Accumulate Systolic (DPAS) is a matrix multiply-add +// operation"}}, {r1, r2, r3}}; This represents a uArch named "XeHPG" with +// description "Intel Xe HPG microarchitecture" that supports 2D tiles with +// dimensions ranging from 1 to 32, 1 to 32, supports a DPAS instruction and has +// 3 restrictions r1, r2, r3 on the uArch +struct uArch { + std::string name; // similar to target triple + std::string description; + // Different kinds of register file information (e.g., GRF, ARF, etc.) + std::vector register_file_info; + // Each level of cache is indexed lower to higher in the vector + // (e.g., L1 indexed at 0, L2 at 1 and so on) L1, L2, L3, etc. 
+ std::vector cache_info; + std::vector instructions; + std::vector *> restrictions; +}; + +// A struct to represent shared memory information +struct SharedMemory { + uint size; + uint alignment; + // @TODO: Add more fields as needed + // uint latency; + // uint throughput; + // uint bandwidth; + // uint num_ports; + // uint port_width; + // uint bank_size; + // uint bank_conflicts; + // uint num_banks; +}; + +// For future use case in Xe4+ + +// struct EUInfo { +// uint num_eu_threads; +// SharedMemory shared_memory; +// }; + +// uint num_simd_units; +// uint num_spus; +// uint num_smt; +// uint num_hardware_threads; +// uint num_threads_per_spu; +// uint num_threads_per_simd_unit; +// uint num_threads_per_hardware_thread; +// uint num_threads_per_smt; +// SharedMemory shared_memory; +// }; + +// A struct to represent a GPU uArch +// This struct is used to represent the GPU microarchitecture of a target device +// struct GPUuArch : public uArch { +// uint num_compute_units; +// uint num_vector_units; +// uint num_scalar_units; +// uint num_tensor_units; +// uint num_matrix_units; +// SharedMemory shared_memory; +// }; +} // namespace uArch +} // namespace xegpu +} // namespace mlir + +#endif // MLIR_DIALECT_XEGPU_UTILS_UARCH_H +//===--- uArch.h ---------------------------------------*- C++ -*-===// diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt index 98e84a4420722..f5736b90ea419 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt @@ -1,11 +1,15 @@ add_mlir_dialect_library(MLIRXeGPUUtils + IntelGpuPVC.cpp XeGPUUtils.cpp + ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/Utils LINK_LIBS PUBLIC MLIRIR + MLIRDialectUtils MLIRSCFTransforms MLIRXeGPUDialect - ) +) + diff --git a/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp b/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp new file mode 100644 index 0000000000000..167f62b2bcab4 --- /dev/null +++ 
b/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp @@ -0,0 +1,48 @@ +#include "mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h" +#include "llvm/Support/YAMLTraits.h" +#include +#include +#include + +using namespace mlir::xegpu::uArch; +using namespace mlir::xegpu::uArch::PVCuArch; + +namespace llvm { +namespace yaml { +template <> +struct MappingTraits { + static void mapping(IO &io, XeCoreInfo &xe_core) { + io.mapRequired("num_threads", xe_core.num_threads); + io.mapRequired("shared_memory", xe_core.shared_memory); + io.mapRequired("num_vector_units", xe_core.num_vector_units); + io.mapRequired("num_matrix_units", xe_core.num_matrix_units); + } +}; + +template <> +struct MappingTraits { + static void mapping(IO &io, Xe2Plus &xe2plus) { + io.mapRequired("xe_core", xe2plus.xe_core); + } +}; +} // namespace yaml +} // namespace llvm + +// namespace mlir { +// namespace xe_gpu { +// namespace namespace mlir { +// namespace xegpu { +// namespace PVCuArchYAML { { +// struct XeCoreInfo { +// uint num_threads; +// SharedMemory shared_memory; +// uint num_vector_units; +// uint num_matrix_units; +// }; + +// struct Xe2Plus { +// XeCoreInfo xe_core; +// }; +// } +// } +// } diff --git a/mlir/lib/Dialect/XeGPU/Utils/intel_gpu_pvc.yaml b/mlir/lib/Dialect/XeGPU/Utils/intel_gpu_pvc.yaml new file mode 100644 index 0000000000000..2ec52da7beb20 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/Utils/intel_gpu_pvc.yaml @@ -0,0 +1,205 @@ +# This file contains the architecture-specific information for the Intel GPU codename PVC. + + +--- +# The target architecture name. +arch: intel_gpu_pvc + +# EU (VectorEngine) overview +number_of_eus_per_xe_core: 8 +number_of_threads_per_eu: [8, 4] # Number of threads in [small, large] grf_mode. + + +shared_memory: + # Shared memory size in bytes. + size: 524288 + # Shared memory alignment in bytes. + alignment: 64 + +# General Register File (GRF) information. +grf_bitwidth: 512 # GRF bitwidth in bits. 
+grf_mode: [small, large] +grf_count_per_thread: [128, 256] # Number of GRFs in [small, large] grf_mode. + + +# Instruction set information. + +instruction: + dpas: + instruction_name: dpas + instruction_description: "Dot Product Accumulate Systolic (DPAS) is a matrix multiply-add operation" + instruction_opcode: "0x83" + instruction_functional_unit: matrix + instruction_type: simd + instruction_scope: sub_group + instruction_unit_of_computation: tile + + # Information about the instruction. + systolic_depth: [8] + repeat_count: [1, 8] # [min, max] + execution_size: [16, 16] # [min, max] + ops_per_channel: + # number_of_bits_per_element: ops_per_channel + 19: 1 # tf32 has 19bits + 16: 2 # f16/bf16 has 16bits + 8: 4 # i8/u8/f8 has 8bits + 4: 8 # i4/u4/f4 has 4bits + 2: 8 # i2/u2 has 2bits + supported_types: + # [Dst, Acc (src0), A (src1), B (src2)] + - [f16, f16, f16, f16] + - [f32, f16, f16, f16] + - [f16, f32, f16, f16] + - [f32, f32, f16, f16] + - [bf16, bf16, bf16, bf16] + - [f32, bf16, bf16, bf16] + - [bf16, f32, bf16, bf16] + - [f32, f32, bf16, bf16] + - [f32, f32, tf32, tf32] + matrix_size: + M: + all_bitwidth: [1, 2, 3, 4, 5, 6, 7, 8] + K: # data type bitwidth: K size + 32: [8] + 16: [16] + 8: [32] + 4: [64] + 2: [64] + N: + all_bitwidth: [16] + + # Load 2D instruction. The 2D Block Load message reads a rectangular block of memory into GRF. + # Up to 32 GRFs (=2048 bytes) can be read using this message. + # Block_width times array_size should not exceed 64 bytes. + load_2d: + instruction_name: load_2d + instruction_description: "2D Block Load instruction reads a rectangular block of memory in to GRF." + instruction_opcode: nan + instruction_functional_unit: load + instruction_type: simd + instruction_scope: sub_group + instruction_unit_of_computation: tile + instruction_type: memory + + # Information about the instruction. 
+ memory_type: global # global=ugm + memory_access_type: read + # supported_types: [i8, i16, i32, i64, ui8, ui16, ui32, ui64, si8, si16, si32, si64, f8, f16, bf16, f32, f64] + supported_types_bitwidth: [8, 16, 32, 64] + alignment: + 8: 8 + 16: 16 + 32: 32 + 64: 64 + + # @TODO: Alternate design: We could do the following in C++ code using Restriction. + # alignment: data_type_bitwidth + + # tile_size: [height, width, array_len] + # height: [min, max] + # width: [min, max] + # array_len: [supported array lengths] + tile_size: # bitwidth: [height, width, array_len] + transpose: + 32: + height: [1, 32] + width: [1, 16] + array_len: [1] + max_width_with_array_len: 16 # width * array_len <= 16 + 64: + height: [1, 32] + width: [1, 8] + array_len: [1] + max_width_with_array_len: 8 # width * array_len <= 8 + vnni: + 8: + height: [4, 32] + width: [4, 16] + array_len: [1, 2, 4] + max_width_with_array_len: 64 # width * array_len <= 64 + 16: + height: [2, 32] + width: [2, 16] + array_len: [1, 2, 4] + max_width_with_array_len: 32 # width * array_len <= 32 + default: + 8: + height: [1, 32] + width: [4, 64] + array_len: [1, 2, 4] + max_width_with_array_len: 64 # width * array_len <= 64 + 16: + height: [1, 32] + width: [2, 32] + array_len: [1, 2, 4] + max_width_with_array_len: 32 # width * array_len <= 32 + 32: + height: [1, 32] + width: [1, 16] + array_len: [1, 2] + max_width_with_array_len: 16 # width * array_len <= 16 + 64: + height: [1, 32] + width: [1, 8] + array_len: [1] + max_width_with_array_len: 8 # width * array_len <= 8 + max_tile_height: 32 + max_tile_width: 64 # 64 bytes. width X array_length <= 64 bytes. + min_surface_pitch: 64 # 64 bytes. + + # Store 2D instruction. 2D Block Store message writes a rectangular block of GRF to memory. + # Up to 512 bytes of GRF data can be written using this message. + # Block_width times array_len should not exceed 64 bytes. 
+ store_2d: + instruction_name: store_2d + instruction_description: "2D Block Store instruction writes a rectangular block of to memory from GRF." + instruction_opcode: 0x00 + instruction_functional_unit: store + instruction_type: simd + instruction_scope: sub_group + instruction_unit_of_computation: tile + instruction_type: memory + + # Information about the instruction. + memory_type: global # global=ugm + memory_access_type: write + # supported_types: [i8, i16, i32, i64, ui8, ui16, ui32, ui64, si8, si16, si32, si64, f8, f16, bf16, f32, f64] + supported_types_bitwidth: [8, 16, 32, 64] + alignment: + 8: 8 + 16: 16 + 32: 32 + 64: 64 + + # @TODO: Alternate design: We could do the following in C++ code using Restriction. + # alignment: data_type_bitwidth + + # tile_size: [height, width, array_len] + # height: [min, max] + # width: [min, max] + # array_len: [supported array lengths] + tile_size: # bitwidth: [height, width, array_len] + default: + 8: + height: [1, 8] + width: [4, 64] + array_len: [1] + max_width_with_array_len: 64 # width * array_len <= 64 + 16: + height: [1, 8] + width: [2, 32] + array_len: [1] + max_width_with_array_len: 32 # width * array_len <= 32 + 32: + height: [1, 8] + width: [1, 16] + array_len: [1] + max_width_with_array_len: 16 # width * array_len <= 16 + 64: + height: [1, 8] + width: [1, 8] + array_len: [1] + max_width_with_array_len: 8 # width * array_len <= 8 + max_tile_height: 32 + max_tile_width: 64 # 64 bytes. width X array_length <= 64 bytes. + min_surface_pitch: 64 # 64 bytes. From f853e5a5f6fddacd0b805c582946c595292a9c4b Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 8 Apr 2025 18:06:40 +0000 Subject: [PATCH 02/11] Address review comments. 
--- mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h b/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h index 0b1328b55ed86..027213cab3b1d 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h @@ -97,4 +97,3 @@ struct LoadStorePrefetch2DInstruction : public Instruction { } // namespace mlir #endif // MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_PVC_H -//===--- IntelGpuPVC.h ---------------------------------------*- C++ -*-===// From 44ea3a4f844cac5021e7e41b3bf912ab16ce35c4 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Thu, 26 Jun 2025 16:05:13 +0000 Subject: [PATCH 03/11] Modify the uArch definition. This version focuses on the utilities to be the pivot. It also saves info directly in C++ files as part of get functions. Don't use the yamls anymore. Adds support for DPAS instruction. --- .../mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h | 28 +-- mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h | 29 +++ mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp | 168 ++++++++++++++++-- 3 files changed, 198 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h b/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h index 027213cab3b1d..f82426f0807f3 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h @@ -37,17 +37,25 @@ struct Xe2Plus : public uArch { }; // struct to represent DPAS instruction -struct DPASInstruction : public Instruction { - Range systolic_depth; - Range repreat_count; - Range execution_size; - std::map ops_per_channel; - std::vector> supported_types; - std::map>> - matrix_size; +struct DPASInstruction : public Instruction, public MatrixOpInterface { + // Range systolic_depth; + // Range repreat_count; + // Range execution_size; + // std::map ops_per_channel; + // std::vector> 
supported_types; + // std::map>> + // matrix_size; - bool checkSupportedDPASTypes(mlir::Type dstType, mlir::Type src0Type, - mlir::Type src1Type, mlir::Type src2Type); + // bool checkSupportedDPASTypes(mlir::Type dstType, mlir::Type src0Type, + // mlir::Type src1Type, mlir::Type src2Type); + virtual bool checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, + mlir::Type CType, + mlir::Type DType) override; + virtual std::vector getSupportedM(mlir::Type type) override; + virtual std::vector getSupportedK(mlir::Type type) override; + virtual std::vector getSupportedN(mlir::Type type) override; + virtual std::vector> + getSupportedMatrix(mlir::Type type, MatrixType matrixType) override; }; struct LoadStore2DTileInfo : public RangeTile { diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h b/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h index 4c23eaec873b1..971694fc93599 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h @@ -277,6 +277,35 @@ struct SharedMemory { // uint num_matrix_units; // SharedMemory shared_memory; // }; + +// Create a TileLikeOp Interface +struct TileOpInterface { + // Get the supported tiles for the specific data type. 
+ // Can provide load/store/prefetch ops supported tile sizes for a specific + // uarch + virtual DiscreteTile getSupportedTiles(mlir::Type type) = 0; + + // Validate the tile ops restrictions + // @param tile, tile to load/store/prefetch + // @param surface, surface to load/store/prefetch data from + // @param dataType, data type of the data + // @param surface_pitch, surface pitch + // @param array_len, array length + virtual bool validate(Tile tile, Tile surface, mlir::Type dataType, + uint surface_pitch, uint array_len = 1) = 0; +}; + +enum class MatrixType { MatrixA, MatrixB, MatrixC, MatrixD }; +struct MatrixOpInterface { + virtual bool checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, + mlir::Type CType, mlir::Type DType) = 0; + virtual std::vector getSupportedM(mlir::Type type) = 0; + virtual std::vector getSupportedK(mlir::Type type) = 0; + virtual std::vector getSupportedN(mlir::Type type) = 0; + virtual std::vector> + getSupportedMatrix(mlir::Type type, MatrixType matrixType) = 0; +}; + } // namespace uArch } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp b/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp index 167f62b2bcab4..2fee6d306f0f9 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp @@ -7,26 +7,160 @@ using namespace mlir::xegpu::uArch; using namespace mlir::xegpu::uArch::PVCuArch; -namespace llvm { -namespace yaml { -template <> -struct MappingTraits { - static void mapping(IO &io, XeCoreInfo &xe_core) { - io.mapRequired("num_threads", xe_core.num_threads); - io.mapRequired("shared_memory", xe_core.shared_memory); - io.mapRequired("num_vector_units", xe_core.num_vector_units); - io.mapRequired("num_matrix_units", xe_core.num_matrix_units); +namespace mlir { +namespace xegpu { +namespace uArch { +namespace PVCuArch { +bool DPASInstruction::checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, + mlir::Type CType, + mlir::Type DType) 
{ + if (AType.isF16() || BType.isF16()) { + if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) || + (!DType.isF32() && !DType.isF16())) + llvm::errs() + << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " + << "Supported types are:\n" + << " Dst | Acc | A | B \n" + << " f, hf | f, hf | hf | hf \n" + << "AType: " << AType << " BType: " << BType << " CType: " << CType + << " DType: " << DType; + return false; + } else if (AType.isBF16() || BType.isBF16()) { + if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) || + (!DType.isF32() && !DType.isBF16())) + llvm::errs() + << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " + << "Supported types are:\n" + << " Dst | Acc | A | B \n" + << " f, bf | f, bf | bf | bf \n" + << "AType: " << AType << " BType: " << BType << " CType: " << CType + << " DType: " << DType; + return false; + } else if (AType.isTF32() || BType.isTF32()) { + if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) || + (!DType.isF32())) + llvm::errs() + << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " + << "Supported types are:\n" + << " Dst | Acc | A | B \n" + << " f | f | tf32 | tf32 \n" + << "AType: " << AType << " BType: " << BType << " CType: " << CType + << " DType: " << DType; + return false; + } else if (!(AType.isInteger(2) || AType.isInteger(4) || + AType.isInteger(8)) && + !(BType.isInteger(2) || BType.isInteger(4) || + BType.isInteger(8))) { + llvm::errs() + << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " + << "Supported types are:\n" + << " Dst | Acc | A | B " + " \n" + << " ud, d | ud,d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 " + << "AType: " << AType << " BType: " << BType << " CType: " << CType + << " DType: " << DType; + return false; } -}; -template <> -struct MappingTraits { - static void mapping(IO &io, Xe2Plus &xe2plus) { - io.mapRequired("xe_core", xe2plus.xe_core); + return true; +} + +std::vector 
DPASInstruction::getSupportedM(mlir::Type type) { + return {1, 2, 3, 4, 5, 6, 7, 8}; +} + +std::vector DPASInstruction::getSupportedK(mlir::Type type) { + // assert if data type is not int or float type + assert(type.isIntOrFloat() && "Matrix type must be int or float"); + auto bitWidth = type.getIntOrFloatBitWidth(); + uint kSize = -1; + switch (bitWidth) { + case 2: + kSize = 64; + break; + case 4: + kSize = 64; + break; + case 8: + kSize = 32; + break; + case 16: + kSize = 16; + break; + case 32: + kSize = 8; + break; + default: + llvm_unreachable("Invalid int or float"); + } +} + +std::vector DPASInstruction::getSupportedN(mlir::Type type) { + return {16}; +} + +std::vector> +DPASInstruction::getSupportedMatrix(mlir::Type type, MatrixType matrixType) { + auto combineVectors = [](const std::vector &a, + const std::vector &b) + -> std::vector> { + std::vector> result; + for (unsigned x : a) { + for (unsigned y : b) { + result.emplace_back(x, y); + } + } + return result; + }; + + auto M = getSupportedM(type); + auto K = getSupportedK(type); + auto N = getSupportedN(type); + std::vector> resultMatrix; + + switch (matrixType) { + case MatrixType::MatrixA: + resultMatrix = combineVectors(M, K); + break; + case MatrixType::MatrixB: + resultMatrix = combineVectors(K, N); + break; + case MatrixType::MatrixC: + resultMatrix = combineVectors(M, N); + break; + case MatrixType::MatrixD: + resultMatrix = combineVectors(M, N); + break; + default: + break; } -}; -} // namespace yaml -} // namespace llvm +} + +} // namespace PVCuArch +} // namespace uArch +} // namespace xegpu +} // namespace mlir + +// namespace llvm { +// namespace yaml { +// template <> +// struct MappingTraits { +// static void mapping(IO &io, XeCoreInfo &xe_core) { +// io.mapRequired("num_threads", xe_core.num_threads); +// io.mapRequired("shared_memory", xe_core.shared_memory); +// io.mapRequired("num_vector_units", xe_core.num_vector_units); +// io.mapRequired("num_matrix_units", 
xe_core.num_matrix_units); +// } +// }; + +// template <> +// struct MappingTraits { +// static void mapping(IO &io, Xe2Plus &xe2plus) { +// io.mapRequired("xe_core", xe2plus.xe_core); +// } +// }; +// } // namespace yaml +// } // namespace llvm // namespace mlir { // namespace xe_gpu { From 6c45d970fa64a25ef02f93aed1fc9010e9fe782a Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 1 Jul 2025 23:52:58 +0000 Subject: [PATCH 04/11] Add necessary infrastructures for the uArch to show the full pipeline. --- .../mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h | 75 +++++++++- mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h | 141 ++++++++++++++++-- mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp | 21 +-- 3 files changed, 216 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h b/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h index f82426f0807f3..f09ffea99b67a 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h @@ -24,16 +24,33 @@ namespace mlir { namespace xegpu { namespace uArch { -namespace PVCuArch { +namespace Xe2Plus { struct XeCoreInfo { uint num_threads; SharedMemory shared_memory; uint num_vector_units; uint num_matrix_units; + + // Constructor + XeCoreInfo(uint num_threads, const SharedMemory &shared_memory, + uint num_vector_units, uint num_matrix_units) + : num_threads(num_threads), shared_memory(shared_memory), + num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) { + } }; struct Xe2Plus : public uArch { XeCoreInfo xe_core; + Xe2Plus(const std::string &archName, const std::string &archDescription, + const XeCoreInfo &xeCore, + const std::vector &hierarchy = {}, + const std::map ®Info = {}, + const std::vector &cacheInfo = {}, + const std::map &instrs = {}, + const std::vector *> &restrs = {}) + : uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs, + restrs), + xe_core(xeCore) {} }; // struct to 
represent DPAS instruction @@ -48,6 +65,18 @@ struct DPASInstruction : public Instruction, public MatrixOpInterface { // bool checkSupportedDPASTypes(mlir::Type dstType, mlir::Type src0Type, // mlir::Type src1Type, mlir::Type src2Type); + + DPASInstruction() + : Instruction("dpas", // name + "Dot Product Accumulate", // description + "0xABCD", // opcode + FunctionalUnit::Matrix, // functional_unit + InstructionType::SIMD, // type + InstructionScope::Subgroup, // scope + UnitOfComputation::Matrix) // unit_of_computation + {} + + // Override all virtuals from MatrixOpInterface virtual bool checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) override; @@ -99,7 +128,51 @@ struct LoadStorePrefetch2DInstruction : public Instruction { } }; +namespace PVCuArch { +struct PVCuArch : public Xe2Plus { + // Maintaines ownership of the instructions owned by PVUarch + std::vector> owned_instructions; + PVCuArch() + : Xe2Plus("pvc", // archName + "Ponte Vecchio Architecture", // archDescription + XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore + {/* register_file_info */}, // Optional: empty + {/* cache_info */}, // Optional: empty + {/* instructions */}, // Optional: empty + {/* restrictions */} // Optional: empty + ) { + // Initialize uArchHierarchy + this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 16)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 4)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2)); + // Intialize register file info + // GRF + this->register_file_info["GRF"] = + RegisterFileInfo(64 * 1024, // size in bits + {"small", "large"}, // GRF modes + {128, 256}, // registers per thread per mode + 0, // number of banks + 0 // bank size + ); + // Initialize cache info + // L1 cache, XeCore level + 
this->cache_info.push_back( + CacheInfo(512 * 1024, 64, this->uArch_hierarchy[1])); + // L3 cache, XeStack level + this->cache_info.push_back( + CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3])); + + // Add the instructions + auto dpas = std::make_unique(); + instructions[dpas->name] = dpas.get(); + owned_instructions.push_back(std::move(dpas)); + } +}; } // namespace PVCuArch + +} // namespace Xe2Plus } // namespace uArch } // namespace xegpu } // namespace mlir diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h b/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h index 971694fc93599..97777d4bb4e5a 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h @@ -16,7 +16,10 @@ #include #include +#include +#include #include + namespace mlir { namespace xegpu { namespace uArch { @@ -99,6 +102,17 @@ struct Restriction { std::any apply() { return std::apply(func, data); } }; +// Architecture HW component hierarchy to present thread, core, socket ... +struct uArchHierarchyComponent { + std::string name = ""; // optional name of the hierarchy component + // no. 
of lower hierarchy component it contains, e.g., for PVC XeCore it + // contains 8 threads, so no_of_component=8 + uint no_of_component; + // Constructor + uArchHierarchyComponent(const std::string &name, uint no_of_component) + : name(name), no_of_component(no_of_component) {} +}; + // An enum class to represent the functional unit of an instruction enum class FunctionalUnit { ALU, @@ -179,6 +193,12 @@ struct Instruction { // std::string pipeline; // std::string resource; // std::string comment; + Instruction(std::string name, std::string desc, std::string opcode, + FunctionalUnit fu, InstructionType itype, InstructionScope sc, + UnitOfComputation uoc) + : name(std::move(name)), description(std::move(desc)), + opcode(std::move(opcode)), functional_unit(fu), type(itype), scope(sc), + unit_of_computation(uoc) {} }; // A struct to represent register file information @@ -189,18 +209,30 @@ struct RegisterFileInfo { num_regs_per_thread_per_mode; // number of registers per thread per mode uint num_banks; uint bank_size; + + // Constructor + RegisterFileInfo(uint size, const std::vector &mode, + const std::vector &numRegs, uint num_banks, + uint bank_size) + : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs), + num_banks(num_banks), bank_size(bank_size) {} }; // A struct to represent cache information struct CacheInfo { uint size; - uint associativity; uint line_size; - uint num_banks; - uint bank_size; - uint num_ports; - uint port_width; - uint bank_conflicts; + // At which component level the cache is shared + uArchHierarchyComponent component; + // uint associativity; + // uint num_banks; + // uint bank_size; + // uint num_ports; + // uint port_width; + // uint bank_conflicts; + // Constructor + CacheInfo(uint size, uint line_size, const uArchHierarchyComponent &component) + : size(size), line_size(line_size), component(component) {} }; // A struct to represent the uArch @@ -225,19 +257,38 @@ struct CacheInfo { struct uArch { std::string name; // similar 
to target triple std::string description; + // Represent the whole uArch hierarchy + // For 2 stack Intel PVC it would look something like this: + // uArchHierarchy[0] = {thread, 0} + // uArchHierarchy[1] = {XeCore, 8} + // uArchHierarchy[2] = {XeSlice, 16} + // uArchHierarchy[3] = {XeStack, 4} + // uArchHierarchy[4] = {gpu, 2} + std::vector uArch_hierarchy; // Different kind of regiger file information (e.g., GRF, ARF, etc.) - std::vector register_file_info; + std::map register_file_info; // Each level of cache is indexed lower to higher in the vector // (e.g., L1 indexed at 0, L2 at 1 and so on) L1, L2, L3, etc. std::vector cache_info; - std::vector instructions; + std::map instructions; std::vector *> restrictions; + + // Constructor + uArch(const std::string &name, const std::string &description, + const std::vector &uArch_hierarchy = {}, + const std::map ®ister_file_info = {}, + const std::vector &cache_info = {}, + const std::map &instructions = {}, + const std::vector *> &restrictions = {}) + : name(name), description(description), uArch_hierarchy(uArch_hierarchy), + register_file_info(register_file_info), cache_info(cache_info), + instructions(instructions), restrictions(restrictions) {} }; // A struct to represent shared memory information struct SharedMemory { - uint size; - uint alignment; + uint size; // in bytes + uint alignment; // in bytes // @TODO: Add more fields as needed // uint latency; // uint throughput; @@ -247,6 +298,9 @@ struct SharedMemory { // uint bank_size; // uint bank_conflicts; // uint num_banks; + + // Constructor + SharedMemory(uint size, uint alignment) : size(size), alignment(alignment) {} }; // For future use case in Xe4+ @@ -293,6 +347,7 @@ struct TileOpInterface { // @param array_len, array length virtual bool validate(Tile tile, Tile surface, mlir::Type dataType, uint surface_pitch, uint array_len = 1) = 0; + virtual ~TileOpInterface() = default; }; enum class MatrixType { MatrixA, MatrixB, MatrixC, MatrixD }; @@ -304,11 
+359,75 @@ struct MatrixOpInterface { virtual std::vector getSupportedN(mlir::Type type) = 0; virtual std::vector> getSupportedMatrix(mlir::Type type, MatrixType matrixType) = 0; + + virtual ~MatrixOpInterface() = default; }; +struct uArchMap { +public: + // Singleton instance + static uArchMap &instance() { + static uArchMap instance; + return instance; + } + + // Insert or update a key-value pair + void insert(const std::string &key, uArch value) { + std::unique_lock lock(mutex_); + map_[key] = value; + } + + // Get a value by key (concurrent safe read) + std::optional get(const std::string &key) const { + std::shared_lock lock(mutex_); + auto it = map_.find(key); + if (it != map_.end()) + return it->second; + return std::nullopt; + } + + // Check if a key exists + bool contains(const std::string &key) const { + std::shared_lock lock(mutex_); + return map_.find(key) != map_.end(); + } + + // Remove a key + bool erase(const std::string &key) { + std::unique_lock lock(mutex_); + return map_.erase(key) > 0; + } + +private: + uArchMap() = default; + uArchMap(const uArchMap &) = delete; + uArchMap &operator=(const uArchMap &) = delete; + + mutable std::shared_mutex mutex_; + std::map map_; +}; + +// std::unordered_map uArchMap; +// std::shared_mutex uArchMapMutex; + +// void getuArch(const std::string &key) { +// std::shared_lock lock(uArchMapMutex); +// auto it = uArchMap.find(key); +// if(it != uArchMap.end()) +// return *it; +// else + +// // safe concurrent read +// } + +// void AdduArch(const std::string &key, uArch &value) { +// std::unique_lock lock(uArchMapMutex); + +// // exclusive write +// } + } // namespace uArch } // namespace xegpu } // namespace mlir #endif // MLIR_DIALECT_XEGPU_UTILS_UARCH_H -//===--- uArch.h ---------------------------------------*- C++ -*-===// diff --git a/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp b/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp index 2fee6d306f0f9..99fe136f22621 100644 --- 
a/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp @@ -5,18 +5,18 @@ #include using namespace mlir::xegpu::uArch; -using namespace mlir::xegpu::uArch::PVCuArch; +using namespace mlir::xegpu::uArch::Xe2Plus; namespace mlir { namespace xegpu { namespace uArch { -namespace PVCuArch { +namespace Xe2Plus { bool DPASInstruction::checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) { if (AType.isF16() || BType.isF16()) { if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) || - (!DType.isF32() && !DType.isF16())) + (!DType.isF32() && !DType.isF16())) { llvm::errs() << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " << "Supported types are:\n" @@ -24,10 +24,11 @@ bool DPASInstruction::checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, << " f, hf | f, hf | hf | hf \n" << "AType: " << AType << " BType: " << BType << " CType: " << CType << " DType: " << DType; - return false; + return false; + } } else if (AType.isBF16() || BType.isBF16()) { if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) || - (!DType.isF32() && !DType.isBF16())) + (!DType.isF32() && !DType.isBF16())) { llvm::errs() << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " << "Supported types are:\n" @@ -35,10 +36,11 @@ bool DPASInstruction::checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, << " f, bf | f, bf | bf | bf \n" << "AType: " << AType << " BType: " << BType << " CType: " << CType << " DType: " << DType; - return false; + return false; + } } else if (AType.isTF32() || BType.isTF32()) { if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) || - (!DType.isF32())) + (!DType.isF32())) { llvm::errs() << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " << "Supported types are:\n" @@ -46,7 +48,8 @@ bool DPASInstruction::checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, << " f | f | tf32 | tf32 \n" << 
"AType: " << AType << " BType: " << BType << " CType: " << CType << " DType: " << DType; - return false; + return false; + } } else if (!(AType.isInteger(2) || AType.isInteger(4) || AType.isInteger(8)) && !(BType.isInteger(2) || BType.isInteger(4) || @@ -136,7 +139,7 @@ DPASInstruction::getSupportedMatrix(mlir::Type type, MatrixType matrixType) { } } -} // namespace PVCuArch +} // namespace Xe2Plus } // namespace uArch } // namespace xegpu } // namespace mlir From ffd9d699ef6f15e5ca7db213756a3e5d8de1bc8a Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Wed, 2 Jul 2025 00:50:50 +0000 Subject: [PATCH 05/11] Add BMG to the uArch. --- .../Utils/{IntelGpuPVC.h => IntelGpuXe2.h} | 44 +++++++++++++++++++ mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt | 2 +- .../{IntelGpuPVC.cpp => IntelGpuXe2.cpp} | 0 3 files changed, 45 insertions(+), 1 deletion(-) rename mlir/include/mlir/Dialect/XeGPU/Utils/{IntelGpuPVC.h => IntelGpuXe2.h} (78%) rename mlir/lib/Dialect/XeGPU/Utils/{IntelGpuPVC.cpp => IntelGpuXe2.cpp} (100%) diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h b/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuXe2.h similarity index 78% rename from mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h rename to mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuXe2.h index f09ffea99b67a..67d025b037c5e 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuXe2.h @@ -172,6 +172,50 @@ struct PVCuArch : public Xe2Plus { }; } // namespace PVCuArch +namespace BMGuArch { +struct BMGuArch : public Xe2Plus { + // Maintaines ownership of the instructions owned by PVUarch + std::vector> owned_instructions; + BMGuArch() + : Xe2Plus("bmg", // archName + "Battlemage Architecture", // archDescription + XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore + {/* register_file_info */}, // Optional: empty + {/* cache_info */}, // Optional: empty + {/* instructions */}, // Optional: empty + {/* 
restrictions */} // Optional: empty + ) { + // Initialize uArchHierarchy + this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 4)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 5)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 1)); + // Intialize register file info + // GRF + this->register_file_info["GRF"] = + RegisterFileInfo(64 * 1024, // size in bits + {"small", "large"}, // GRF modes + {128, 256}, // registers per thread per mode + 0, // number of banks + 0 // bank size + ); + // Initialize cache info + // L1 cache, XeCore level + this->cache_info.push_back( + CacheInfo(256 * 1024, 64, this->uArch_hierarchy[1])); + // L3 cache, XeStack level + this->cache_info.push_back( + CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3])); + + // Add the instructions + auto dpas = std::make_unique(); + instructions[dpas->name] = dpas.get(); + owned_instructions.push_back(std::move(dpas)); + } +}; +} // namespace BMGuArch + } // namespace Xe2Plus } // namespace uArch } // namespace xegpu diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt index f5736b90ea419..d104495b42096 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_dialect_library(MLIRXeGPUUtils - IntelGpuPVC.cpp + IntelGpuXe2.cpp XeGPUUtils.cpp diff --git a/mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp b/mlir/lib/Dialect/XeGPU/Utils/IntelGpuXe2.cpp similarity index 100% rename from mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp rename to mlir/lib/Dialect/XeGPU/Utils/IntelGpuXe2.cpp From af7098b5759eec5d2b4bee67596a1c835111c3b2 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Wed, 2 Jul 2025 04:45:53 +0000 Subject: [PATCH 06/11] Move uArch to a new folder. 
Remove compiler warnings and errors. --- .../XeGPU/{Utils => uArch}/IntelGpuXe2.h | 51 +++---- .../Dialect/XeGPU/{Utils => uArch}/uArch.h | 127 ++++++++++-------- mlir/lib/Dialect/XeGPU/CMakeLists.txt | 1 + mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt | 3 - mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt | 12 ++ .../XeGPU/{Utils => uArch}/IntelGpuXe2.cpp | 20 +-- .../XeGPU/{Utils => uArch}/intel_gpu_pvc.yaml | 0 7 files changed, 117 insertions(+), 97 deletions(-) rename mlir/include/mlir/Dialect/XeGPU/{Utils => uArch}/IntelGpuXe2.h (85%) rename mlir/include/mlir/Dialect/XeGPU/{Utils => uArch}/uArch.h (83%) create mode 100644 mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt rename mlir/lib/Dialect/XeGPU/{Utils => uArch}/IntelGpuXe2.cpp (92%) rename mlir/lib/Dialect/XeGPU/{Utils => uArch}/intel_gpu_pvc.yaml (100%) diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h similarity index 85% rename from mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuXe2.h rename to mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h index 67d025b037c5e..679d8d833877a 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuXe2.h +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h @@ -11,10 +11,10 @@ /// // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_PVC_H -#define MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_PVC_H +#ifndef MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_XE2_H +#define MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_XE2_H -#include "mlir/Dialect/XeGPU/Utils/uArch.h" +#include "mlir/Dialect/XeGPU/uArch/uArch.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include @@ -26,14 +26,14 @@ namespace xegpu { namespace uArch { namespace Xe2Plus { struct XeCoreInfo { - uint num_threads; + uint32_t num_threads; SharedMemory shared_memory; - uint num_vector_units; - uint num_matrix_units; + uint32_t num_vector_units; + uint32_t num_matrix_units; // 
Constructor - XeCoreInfo(uint num_threads, const SharedMemory &shared_memory, - uint num_vector_units, uint num_matrix_units) + XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory, + uint32_t num_vector_units, uint32_t num_matrix_units) : num_threads(num_threads), shared_memory(shared_memory), num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) { } @@ -58,7 +58,7 @@ struct DPASInstruction : public Instruction, public MatrixOpInterface { // Range systolic_depth; // Range repreat_count; // Range execution_size; - // std::map ops_per_channel; + // std::map ops_per_channel; // std::vector> supported_types; // std::map>> // matrix_size; @@ -80,15 +80,15 @@ struct DPASInstruction : public Instruction, public MatrixOpInterface { virtual bool checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) override; - virtual std::vector getSupportedM(mlir::Type type) override; - virtual std::vector getSupportedK(mlir::Type type) override; - virtual std::vector getSupportedN(mlir::Type type) override; + virtual std::vector getSupportedM(mlir::Type type) override; + virtual std::vector getSupportedK(mlir::Type type) override; + virtual std::vector getSupportedN(mlir::Type type) override; virtual std::vector> getSupportedMatrix(mlir::Type type, MatrixType matrixType) override; }; struct LoadStore2DTileInfo : public RangeTile { - std::vector array_len; + std::vector array_len; }; // struct to represent Load2D/Store2D/Prefetch instruction @@ -96,18 +96,18 @@ struct LoadStorePrefetch2DInstruction : public Instruction { MemoryType memory_type; MemoryAccessType memory_access_type; // std::vector supported_types; - std::vector supported_types_bitwidth; - std::map alignment; + std::vector supported_types_bitwidth; + std::map alignment; LoadStore2DTileInfo supported_tile_sizes; - uint min_surface_pitch; + uint32_t min_surface_pitch; // Validate Array length restriction on a given tile - bool 
validateArrayLenRestriction(Tile tile, uint array_len, + bool validateArrayLenRestriction(Tile tile, uint32_t array_len, mlir::Type dataType) { - Restriction width_array_len_restriction( + Restriction width_array_len_restriction( tile, array_len, dataType, - [](Tile tile, uint array_len, mlir::Type dataType) { + [](Tile tile, uint32_t array_len, mlir::Type dataType) { assert(tile.no_of_dims == 2); return tile.dims[1] * array_len * (dataType.getIntOrFloatBitWidth() / 8) <= @@ -118,9 +118,9 @@ struct LoadStorePrefetch2DInstruction : public Instruction { // Validate Surface Pitch restriction on a given tile bool validateSurfacePitchRestriction(Tile tile, - uint surfacePitch /*in bytes*/) { - Restriction surface_pitch_restriction( - tile, surfacePitch, [](Tile tile, uint surfacePitch) { + uint32_t surfacePitch /*in bytes*/) { + Restriction surface_pitch_restriction( + tile, surfacePitch, [](Tile tile, uint32_t surfacePitch) { assert(tile.no_of_dims == 2); return surfacePitch >= 64; }); @@ -149,13 +149,14 @@ struct PVCuArch : public Xe2Plus { this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2)); // Intialize register file info // GRF - this->register_file_info["GRF"] = + this->register_file_info.emplace( + "GRF", RegisterFileInfo(64 * 1024, // size in bits {"small", "large"}, // GRF modes {128, 256}, // registers per thread per mode 0, // number of banks 0 // bank size - ); + )); // Initialize cache info // L1 cache, XeCore level this->cache_info.push_back( @@ -221,4 +222,4 @@ struct BMGuArch : public Xe2Plus { } // namespace xegpu } // namespace mlir -#endif // MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_PVC_H +#endif // MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_XE2_H diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h similarity index 83% rename from mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h rename to mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h index 97777d4bb4e5a..cc04b18427ad8 100644 --- 
a/mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h @@ -14,12 +14,16 @@ #ifndef MLIR_DIALECT_XEGPU_UTILS_UARCH_H #define MLIR_DIALECT_XEGPU_UTILS_UARCH_H +#include #include #include +#include #include #include #include +#include "mlir/IR/Types.h" + namespace mlir { namespace xegpu { namespace uArch { @@ -37,8 +41,8 @@ struct Range { // dim: [2, 2] // This represents a 2x2 tile struct Tile { - uint no_of_dims; - std::vector dims; + uint32_t no_of_dims; + std::vector dims; }; // RangeTile represents a range of tiles instead of a single tile @@ -52,7 +56,7 @@ struct Tile { // This represents a 2x2 RangeTile where the first dimension can have values // from 1 to 32 and the second dimension can have values from 2 to 16 struct RangeTile { - uint no_of_dims; + uint32_t no_of_dims; std::vector dims; }; @@ -68,8 +72,8 @@ struct RangeTile { // This represents a 2x2 DiscreteTile where the first dimension can have values // 1, 2, 4, 8, 16, 32 and the second dimension can have values 2, 4, 8, 16 struct DiscreteTile { - uint no_of_dims; - std::vector> dims; + uint32_t no_of_dims; + std::vector> dims; }; // Restriction struct @@ -93,9 +97,9 @@ struct DiscreteTile { template struct Restriction { std::tuple data; - std::function func; + std::function func; - Restriction(Args... args, std::function f) + Restriction(Args... args, std::function f) : data(args...), func(f) {} bool validate() { return std::apply(func, data); } @@ -107,9 +111,9 @@ struct uArchHierarchyComponent { std::string name = ""; // optional name of the hierarchy component // no. 
of lower hierarchy component it contains, e.g., for PVC XeCore it // contains 8 threads, so no_of_component=8 - uint no_of_component; + uint32_t no_of_component; // Constructor - uArchHierarchyComponent(const std::string &name, uint no_of_component) + uArchHierarchyComponent(const std::string &name, uint32_t no_of_component) : name(name), no_of_component(no_of_component) {} }; @@ -203,35 +207,37 @@ struct Instruction { // A struct to represent register file information struct RegisterFileInfo { - uint size; // size per register in bits + uint32_t size; // size per register in bits std::vector mode; // e.g., "small", "large" GRF modes - std::vector + std::vector num_regs_per_thread_per_mode; // number of registers per thread per mode - uint num_banks; - uint bank_size; + uint32_t num_banks; + uint32_t bank_size; // Constructor - RegisterFileInfo(uint size, const std::vector &mode, - const std::vector &numRegs, uint num_banks, - uint bank_size) + RegisterFileInfo() = default; + RegisterFileInfo(uint32_t size, const std::vector &mode, + const std::vector &numRegs, uint32_t num_banks, + uint32_t bank_size) : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs), num_banks(num_banks), bank_size(bank_size) {} }; // A struct to represent cache information struct CacheInfo { - uint size; - uint line_size; + uint32_t size; + uint32_t line_size; // At which component level the cache is shared uArchHierarchyComponent component; - // uint associativity; - // uint num_banks; - // uint bank_size; - // uint num_ports; - // uint port_width; - // uint bank_conflicts; + // uint32_t associativity; + // uint32_t num_banks; + // uint32_t bank_size; + // uint32_t num_ports; + // uint32_t port_width; + // uint32_t bank_conflicts; // Constructor - CacheInfo(uint size, uint line_size, const uArchHierarchyComponent &component) + CacheInfo(uint32_t size, uint32_t line_size, + const uArchHierarchyComponent &component) : size(size), line_size(line_size), component(component) {} }; @@ 
-274,6 +280,7 @@ struct uArch { std::vector *> restrictions; // Constructor + uArch() = default; uArch(const std::string &name, const std::string &description, const std::vector &uArch_hierarchy = {}, const std::map ®ister_file_info = {}, @@ -287,48 +294,49 @@ struct uArch { // A struct to represent shared memory information struct SharedMemory { - uint size; // in bytes - uint alignment; // in bytes + uint32_t size; // in bytes + uint32_t alignment; // in bytes // @TODO: Add more fields as needed - // uint latency; - // uint throughput; - // uint bandwidth; - // uint num_ports; - // uint port_width; - // uint bank_size; - // uint bank_conflicts; - // uint num_banks; + // uint32_t latency; + // uint32_t throughput; + // uint32_t bandwidth; + // uint32_t num_ports; + // uint32_t port_width; + // uint32_t bank_size; + // uint32_t bank_conflicts; + // uint32_t num_banks; // Constructor - SharedMemory(uint size, uint alignment) : size(size), alignment(alignment) {} + SharedMemory(uint32_t size, uint32_t alignment) + : size(size), alignment(alignment) {} }; // For future use case in Xe4+ // struct EUInfo { -// uint num_eu_threads; +// uint32_t num_eu_threads; // SharedMemory shared_memory; // }; -// uint num_simd_units; -// uint num_spus; -// uint num_smt; -// uint num_hardware_threads; -// uint num_threads_per_spu; -// uint num_threads_per_simd_unit; -// uint num_threads_per_hardware_thread; -// uint num_threads_per_smt; +// uint32_t num_simd_units; +// uint32_t num_spus; +// uint32_t num_smt; +// uint32_t num_hardware_threads; +// uint32_t num_threads_per_spu; +// uint32_t num_threads_per_simd_unit; +// uint32_t num_threads_per_hardware_thread; +// uint32_t num_threads_per_smt; // SharedMemory shared_memory; // }; // A struct to represent a GPU uArch // This struct is used to represent the GPU microarchitecture of a target device // struct GPUuArch : public uArch { -// uint num_compute_units; -// uint num_vector_units; -// uint num_scalar_units; -// uint 
num_tensor_units; -// uint num_matrix_units; +// uint32_t num_compute_units; +// uint32_t num_vector_units; +// uint32_t num_scalar_units; +// uint32_t num_tensor_units; +// uint32_t num_matrix_units; // SharedMemory shared_memory; // }; @@ -346,7 +354,7 @@ struct TileOpInterface { // @param surface_pitch, suface pitch // @param array_len, array length virtual bool validate(Tile tile, Tile surface, mlir::Type dataType, - uint surface_pitch, uint array_len = 1) = 0; + uint32_t surface_pitch, uint32_t array_len = 1) = 0; virtual ~TileOpInterface() = default; }; @@ -354,9 +362,9 @@ enum class MatrixType { MatrixA, MatrixB, MatrixC, MatrixD }; struct MatrixOpInterface { virtual bool checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) = 0; - virtual std::vector getSupportedM(mlir::Type type) = 0; - virtual std::vector getSupportedK(mlir::Type type) = 0; - virtual std::vector getSupportedN(mlir::Type type) = 0; + virtual std::vector getSupportedM(mlir::Type type) = 0; + virtual std::vector getSupportedK(mlir::Type type) = 0; + virtual std::vector getSupportedN(mlir::Type type) = 0; virtual std::vector> getSupportedMatrix(mlir::Type type, MatrixType matrixType) = 0; @@ -373,13 +381,14 @@ struct uArchMap { // Insert or update a key-value pair void insert(const std::string &key, uArch value) { - std::unique_lock lock(mutex_); - map_[key] = value; + std::unique_lock lock(mutex_); + // map_[key] = value; + map_.emplace(key, value); } // Get a value by key (concurrent safe read) std::optional get(const std::string &key) const { - std::shared_lock lock(mutex_); + std::shared_lock lock(mutex_); auto it = map_.find(key); if (it != map_.end()) return it->second; @@ -388,13 +397,13 @@ struct uArchMap { // Check if a key exists bool contains(const std::string &key) const { - std::shared_lock lock(mutex_); + std::shared_lock lock(mutex_); return map_.find(key) != map_.end(); } // Remove a key bool erase(const std::string &key) { - 
std::unique_lock lock(mutex_); + std::unique_lock lock(mutex_); return map_.erase(key) > 0; } diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt index 31167e6af908b..9079df050ab2b 100644 --- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(uArch) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt index d104495b42096..8fa908087c0ae 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt @@ -1,14 +1,11 @@ add_mlir_dialect_library(MLIRXeGPUUtils - IntelGpuXe2.cpp XeGPUUtils.cpp - ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/Utils LINK_LIBS PUBLIC MLIRIR - MLIRDialectUtils MLIRSCFTransforms MLIRXeGPUDialect ) diff --git a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt new file mode 100644 index 0000000000000..b880f9abf04ac --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect_library(MLIRXeGPUuArch + IntelGpuXe2.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/uArch + + LINK_LIBS PUBLIC + MLIRIR + MLIRDialectUtils + MLIRXeGPUDialect +) + diff --git a/mlir/lib/Dialect/XeGPU/Utils/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp similarity index 92% rename from mlir/lib/Dialect/XeGPU/Utils/IntelGpuXe2.cpp rename to mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp index 99fe136f22621..7172b2aa11f46 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/IntelGpuXe2.cpp +++ b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp @@ -1,4 +1,4 @@ -#include "mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "llvm/Support/YAMLTraits.h" #include #include @@ -68,15 +68,15 @@ bool 
DPASInstruction::checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, return true; } -std::vector DPASInstruction::getSupportedM(mlir::Type type) { +std::vector DPASInstruction::getSupportedM(mlir::Type type) { return {1, 2, 3, 4, 5, 6, 7, 8}; } -std::vector DPASInstruction::getSupportedK(mlir::Type type) { +std::vector DPASInstruction::getSupportedK(mlir::Type type) { // assert if data type is not int or float type assert(type.isIntOrFloat() && "Matrix type must be int or float"); auto bitWidth = type.getIntOrFloatBitWidth(); - uint kSize = -1; + uint32_t kSize = 0; switch (bitWidth) { case 2: kSize = 64; @@ -96,9 +96,10 @@ std::vector DPASInstruction::getSupportedK(mlir::Type type) { default: llvm_unreachable("Invalid int or float"); } + return {kSize}; } -std::vector DPASInstruction::getSupportedN(mlir::Type type) { +std::vector DPASInstruction::getSupportedN(mlir::Type type) { return {16}; } @@ -134,9 +135,8 @@ DPASInstruction::getSupportedMatrix(mlir::Type type, MatrixType matrixType) { case MatrixType::MatrixD: resultMatrix = combineVectors(M, N); break; - default: - break; } + return resultMatrix; } } // namespace Xe2Plus @@ -171,10 +171,10 @@ DPASInstruction::getSupportedMatrix(mlir::Type type, MatrixType matrixType) { // namespace xegpu { // namespace PVCuArchYAML { { // struct XeCoreInfo { -// uint num_threads; +// uint32_t num_threads; // SharedMemory shared_memory; -// uint num_vector_units; -// uint num_matrix_units; +// uint32_t num_vector_units; +// uint32_t num_matrix_units; // }; // struct Xe2Plus { diff --git a/mlir/lib/Dialect/XeGPU/Utils/intel_gpu_pvc.yaml b/mlir/lib/Dialect/XeGPU/uArch/intel_gpu_pvc.yaml similarity index 100% rename from mlir/lib/Dialect/XeGPU/Utils/intel_gpu_pvc.yaml rename to mlir/lib/Dialect/XeGPU/uArch/intel_gpu_pvc.yaml From 6a616037ad0bf12620d9e3029b959e73b1346a9b Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Wed, 2 Jul 2025 11:44:56 +0000 Subject: [PATCH 07/11] Use shared_pointer across the 
board. --- .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 33 +++++----- mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h | 64 ++++++++++++++++--- 2 files changed, 74 insertions(+), 23 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h index 679d8d833877a..e3aa393febf3b 100644 --- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h @@ -41,13 +41,14 @@ struct XeCoreInfo { struct Xe2Plus : public uArch { XeCoreInfo xe_core; - Xe2Plus(const std::string &archName, const std::string &archDescription, - const XeCoreInfo &xeCore, - const std::vector &hierarchy = {}, - const std::map ®Info = {}, - const std::vector &cacheInfo = {}, - const std::map &instrs = {}, - const std::vector *> &restrs = {}) + Xe2Plus( + const std::string &archName, const std::string &archDescription, + const XeCoreInfo &xeCore, + const std::vector &hierarchy = {}, + const std::map ®Info = {}, + const std::vector &cacheInfo = {}, + const std::map> &instrs = {}, + const std::vector *> &restrs = {}) : uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs, restrs), xe_core(xeCore) {} @@ -131,7 +132,7 @@ struct LoadStorePrefetch2DInstruction : public Instruction { namespace PVCuArch { struct PVCuArch : public Xe2Plus { // Maintaines ownership of the instructions owned by PVUarch - std::vector> owned_instructions; + std::vector> owned_instructions; PVCuArch() : Xe2Plus("pvc", // archName "Ponte Vecchio Architecture", // archDescription @@ -166,9 +167,10 @@ struct PVCuArch : public Xe2Plus { CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3])); // Add the instructions - auto dpas = std::make_unique(); - instructions[dpas->name] = dpas.get(); - owned_instructions.push_back(std::move(dpas)); + auto dpas = std::make_shared(); + instructions.emplace(dpas->name, dpas); + // instructions[dpas->name] = dpas.get(); + owned_instructions.push_back(dpas); } }; } // 
namespace PVCuArch @@ -176,7 +178,7 @@ struct PVCuArch : public Xe2Plus { namespace BMGuArch { struct BMGuArch : public Xe2Plus { // Maintaines ownership of the instructions owned by PVUarch - std::vector> owned_instructions; + std::vector> owned_instructions; BMGuArch() : Xe2Plus("bmg", // archName "Battlemage Architecture", // archDescription @@ -210,9 +212,10 @@ struct BMGuArch : public Xe2Plus { CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3])); // Add the instructions - auto dpas = std::make_unique(); - instructions[dpas->name] = dpas.get(); - owned_instructions.push_back(std::move(dpas)); + auto dpas = std::make_shared(); + instructions.emplace(dpas->name, dpas); + // instructions[dpas->name] = dpas.get(); + owned_instructions.push_back(dpas); } }; } // namespace BMGuArch diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h index cc04b18427ad8..35652cf669bdb 100644 --- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h @@ -203,6 +203,8 @@ struct Instruction { : name(std::move(name)), description(std::move(desc)), opcode(std::move(opcode)), functional_unit(fu), type(itype), scope(sc), unit_of_computation(uoc) {} + + virtual ~Instruction() = default; }; // A struct to represent register file information @@ -276,7 +278,7 @@ struct uArch { // Each level of cache is indexed lower to higher in the vector // (e.g., L1 indexed at 0, L2 at 1 and so on) L1, L2, L3, etc. 
std::vector cache_info; - std::map instructions; + std::map> instructions; std::vector *> restrictions; // Constructor @@ -285,7 +287,8 @@ struct uArch { const std::vector &uArch_hierarchy = {}, const std::map ®ister_file_info = {}, const std::vector &cache_info = {}, - const std::map &instructions = {}, + const std::map> + &instructions = {}, const std::vector *> &restrictions = {}) : name(name), description(description), uArch_hierarchy(uArch_hierarchy), register_file_info(register_file_info), cache_info(cache_info), @@ -380,19 +383,19 @@ struct uArchMap { } // Insert or update a key-value pair - void insert(const std::string &key, uArch value) { + void insert(const std::string &key, std::shared_ptr value) { std::unique_lock lock(mutex_); - // map_[key] = value; - map_.emplace(key, value); + // map_[key] = std::move(value); // safe to overwrite + map_.emplace(key, std::move(value)); // safe to overwrite } // Get a value by key (concurrent safe read) - std::optional get(const std::string &key) const { + std::shared_ptr get(const std::string &key) const { std::shared_lock lock(mutex_); auto it = map_.find(key); if (it != map_.end()) return it->second; - return std::nullopt; + return nullptr; } // Check if a key exists @@ -413,9 +416,54 @@ struct uArchMap { uArchMap &operator=(const uArchMap &) = delete; mutable std::shared_mutex mutex_; - std::map map_; + std::map> map_; }; +// struct uArchMap { +// public: +// // Singleton instance +// static uArchMap &instance() { +// static uArchMap instance; +// return instance; +// } + +// // Insert or update a key-value pair +// void insert(const std::string &key, uArch value) { +// std::unique_lock lock(mutex_); +// // map_[key] = value; +// map_.emplace(key, value); +// } + +// // Get a value by key (concurrent safe read) +// std::optional get(const std::string &key) const { +// std::shared_lock lock(mutex_); +// auto it = map_.find(key); +// if (it != map_.end()) +// return it->second; +// return std::nullopt; +// } + +// 
// Check if a key exists +// bool contains(const std::string &key) const { +// std::shared_lock lock(mutex_); +// return map_.find(key) != map_.end(); +// } + +// // Remove a key +// bool erase(const std::string &key) { +// std::unique_lock lock(mutex_); +// return map_.erase(key) > 0; +// } + +// private: +// uArchMap() = default; +// uArchMap(const uArchMap &) = delete; +// uArchMap &operator=(const uArchMap &) = delete; + +// mutable std::shared_mutex mutex_; +// std::map> map_; +// }; + // std::unordered_map uArchMap; // std::shared_mutex uArchMapMutex; From 9d4dfcaa4f462dba6337addeb60e15528164306d Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Wed, 2 Jul 2025 11:46:01 +0000 Subject: [PATCH 08/11] Add xegpu-attch-target-device pass. --- .../mlir/Dialect/XeGPU/Transforms/Passes.td | 17 ++++ .../Dialect/XeGPU/Transforms/CMakeLists.txt | 1 + .../Transforms/XeGPUAttachTargetDevice.cpp | 80 +++++++++++++++++++ mlir/test/Dialect/XeGPU/uarch-info.mlir | 14 ++++ 4 files changed, 112 insertions(+) create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUAttachTargetDevice.cpp create mode 100644 mlir/test/Dialect/XeGPU/uarch-info.mlir diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td index 3a88dae041dd1..07c368efb4273 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td @@ -71,4 +71,21 @@ def XeGPUBlocking: Pass<"xegpu-blocking"> { ]; } +def XeGPUAttachTargetDevice : Pass<"xegpu-attach-target-device", "ModuleOp"> { + let summary = "Attach a dlti.target_system_spec entry with a named device"; + let description = [{ + This pass attaches a `dlti.target_system_spec` attribute to the module + with a device entry like `#dlti.dl_entry<"name", "">`. + }]; + + let options = [ + Option<"deviceName", "device-name", "std::string", + /*default=*/"\"pvc\"", + "Name of the target device to attach (e.g. 
pvc)">, + ]; + let dependentDialects = [ + "xegpu::XeGPUDialect", "mlir::DLTIDialect" + ]; +} + #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt index 9c178d1d85642..dbde19f11da33 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRXeGPUTransforms + XeGPUAttachTargetDevice.cpp XeGPUBlocking.cpp XeGPUFoldAliasOps.cpp XeGPUSubgroupDistribute.cpp diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUAttachTargetDevice.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUAttachTargetDevice.cpp new file mode 100644 index 0000000000000..d5fa7d446d273 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUAttachTargetDevice.cpp @@ -0,0 +1,80 @@ +//===-- XeGPUAttachTargetDevice.cpp ---- XeGPU Attach Target Device Pass --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Transforms/Passes.h" +#include "mlir/Dialect/XeGPU/Transforms/Transforms.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace xegpu { +#define GEN_PASS_DEF_XEGPUATTACHTARGETDEVICE +#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" +} // namespace xegpu +} // namespace mlir + +using namespace mlir; + +namespace { +struct XeGPUAttachTargetDevicePass final + : public xegpu::impl::XeGPUAttachTargetDeviceBase< + XeGPUAttachTargetDevicePass> { + using Base::Base; + void runOnOperation() override; +}; +} // namespace + +void XeGPUAttachTargetDevicePass::runOnOperation() { + ModuleOp module = getOperation(); + MLIRContext *ctx = module.getContext(); + Builder b(ctx); + + // Build #dlti.dl_entry<"name", ""> + // auto nameEntry = dlti::DLEntryAttr::get(ctx, b.getStringAttr("name"), + // b.getStringAttr(deviceName)); + + auto nameEntry = DataLayoutEntryAttr::get(b.getStringAttr("name"), + b.getStringAttr(deviceName)); + + // Build #dlti.target_device_spec<...> + TargetDeviceSpecInterface deviceSpec = + TargetDeviceSpecAttr::get(ctx, {nameEntry}); + + // Construct a dl_entry for "GPU" = deviceSpec + auto sysSpecVal = + DataLayoutEntryAttr::get(b.getStringAttr("GPU"), deviceSpec); + + // Cast to the expected interface + DataLayoutEntryInterface sysSpecIface = + llvm::dyn_cast(sysSpecVal); + + // Now build target system spec + auto systemSpec = TargetSystemSpecAttr::get( + ctx, ArrayRef{sysSpecIface}); + + // Attach to module + module->setAttr("dlti.target_system_spec", systemSpec); + + // Create the uArch 
object for the target device and add it to the uArchMap + + if (deviceName == "pvc") { + auto pvcuArch = + std::make_shared(); + mlir::xegpu::uArch::uArchMap::instance().insert(deviceName, pvcuArch); + } else if (deviceName == "bmg") { + auto bmguArch = + std::make_shared(); + mlir::xegpu::uArch::uArchMap::instance().insert(deviceName, bmguArch); + } +} diff --git a/mlir/test/Dialect/XeGPU/uarch-info.mlir b/mlir/test/Dialect/XeGPU/uarch-info.mlir new file mode 100644 index 0000000000000..351c96f8169c5 --- /dev/null +++ b/mlir/test/Dialect/XeGPU/uarch-info.mlir @@ -0,0 +1,14 @@ +module @eltwise_add attributes {gpu.container_module} { + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + + gpu.func @dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %load_a = xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout> + %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<32x24xf32, #xegpu.layout> -> vector<32x24xf32> + %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + gpu.return + } + } +} From 78677f47b4e8a9b80dcb364234198bf705018f61 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Wed, 2 Jul 2025 16:40:49 +0000 Subject: [PATCH 09/11] Add the usage of uArch in XeGPU verification. 
--- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 41 ++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 2793c7a35bc97..80cd36f8328a7 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -7,9 +7,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" @@ -577,6 +579,45 @@ LogicalResult DpasOp::verify() { if (getAcc() && getAcc().getType() != getResultType()) return emitOpError("Expecting the acc type to be the same as result."); + // @uArch: Check if the types are supported for DPAS. + Operation *op = getOperation(); + auto moduleOp = op->getParentOfType(); + if (!moduleOp) + llvm::errs() << "No parent module op.\n"; + + auto targetDeviceNameAttr = dlti::query(moduleOp, {"GPU", "name"}); + if (failed(targetDeviceNameAttr)) + llvm::errs() + << "No target device found, skipping target-specific verification\n"; + + // Potential usage of uArch in verification. + if (succeeded(targetDeviceNameAttr)) { + auto targetDeviceNameStr = + llvm::dyn_cast(targetDeviceNameAttr.value()).str(); + auto targetDeviceArch = + mlir::xegpu::uArch::uArchMap::instance().get(targetDeviceNameStr); + if (targetDeviceArch) { + // @TODO: We should keep the name of the Instructions in one place, since + // we use the name of the instruction to find the instruction, it should + // be standardized and kept for users to access. 
+ auto it = targetDeviceArch->instructions.find("dpas"); + if (it != targetDeviceArch->instructions.end()) { + std::shared_ptr instr = it->second; + std::cout << "Found instruction: " << instr->name << std::endl; + auto matrixOp = + std::dynamic_pointer_cast( + instr); + if (matrixOp) { + if (!matrixOp->checkSupportedMMATypes( + getLhsType().getElementType(), getRhsType().getElementType(), + getResultType().getElementType(), + getResultType().getElementType())) + return emitOpError("Unsupported DPAS types."); + } + } + } + } + // SIMT code: the size of the B operand has to be a multiple of 32 bits. // It skips the semantic check since lack of architecture information. // Users need to ensure the correctness. From b05de857fd58a0e8ce53ec15a5f9222828433568 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Wed, 2 Jul 2025 18:23:54 +0000 Subject: [PATCH 10/11] Remove comments, add test cases. --- mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h | 64 ------------------- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 2 +- .../Transforms/XeGPUAttachTargetDevice.cpp | 3 +- .../Dialect/XeGPU/attach-target-device.mlir | 54 ++++++++++++++++ 4 files changed, 57 insertions(+), 66 deletions(-) create mode 100644 mlir/test/Dialect/XeGPU/attach-target-device.mlir diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h index 35652cf669bdb..0bd16eaa75d59 100644 --- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArch.h @@ -419,70 +419,6 @@ struct uArchMap { std::map> map_; }; -// struct uArchMap { -// public: -// // Singleton instance -// static uArchMap &instance() { -// static uArchMap instance; -// return instance; -// } - -// // Insert or update a key-value pair -// void insert(const std::string &key, uArch value) { -// std::unique_lock lock(mutex_); -// // map_[key] = value; -// map_.emplace(key, value); -// } - -// // Get a value by key (concurrent safe read) -// 
std::optional get(const std::string &key) const { -// std::shared_lock lock(mutex_); -// auto it = map_.find(key); -// if (it != map_.end()) -// return it->second; -// return std::nullopt; -// } - -// // Check if a key exists -// bool contains(const std::string &key) const { -// std::shared_lock lock(mutex_); -// return map_.find(key) != map_.end(); -// } - -// // Remove a key -// bool erase(const std::string &key) { -// std::unique_lock lock(mutex_); -// return map_.erase(key) > 0; -// } - -// private: -// uArchMap() = default; -// uArchMap(const uArchMap &) = delete; -// uArchMap &operator=(const uArchMap &) = delete; - -// mutable std::shared_mutex mutex_; -// std::map> map_; -// }; - -// std::unordered_map uArchMap; -// std::shared_mutex uArchMapMutex; - -// void getuArch(const std::string &key) { -// std::shared_lock lock(uArchMapMutex); -// auto it = uArchMap.find(key); -// if(it != uArchMap.end()) -// return *it; -// else - -// // safe concurrent read -// } - -// void AdduArch(const std::string &key, uArch &value) { -// std::unique_lock lock(uArchMapMutex); - -// // exclusive write -// } - } // namespace uArch } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 80cd36f8328a7..b8919b76ba3f7 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -585,6 +585,7 @@ LogicalResult DpasOp::verify() { if (!moduleOp) llvm::errs() << "No parent module op.\n"; + // If target device info is not attached, skip the target-specific checks auto targetDeviceNameAttr = dlti::query(moduleOp, {"GPU", "name"}); if (failed(targetDeviceNameAttr)) llvm::errs() @@ -603,7 +604,6 @@ LogicalResult DpasOp::verify() { auto it = targetDeviceArch->instructions.find("dpas"); if (it != targetDeviceArch->instructions.end()) { std::shared_ptr instr = it->second; - std::cout << "Found instruction: " << instr->name << std::endl; auto matrixOp = 
std::dynamic_pointer_cast( instr); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUAttachTargetDevice.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUAttachTargetDevice.cpp index d5fa7d446d273..ea1e24b8e28d2 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUAttachTargetDevice.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUAttachTargetDevice.cpp @@ -67,7 +67,8 @@ void XeGPUAttachTargetDevicePass::runOnOperation() { module->setAttr("dlti.target_system_spec", systemSpec); // Create the uArch object for the target device and add it to the uArchMap - + // We don't have to do it here, we can do it in the Dialect initialization + // phase, this is just showing one way of doing it if (deviceName == "pvc") { auto pvcuArch = std::make_shared(); diff --git a/mlir/test/Dialect/XeGPU/attach-target-device.mlir b/mlir/test/Dialect/XeGPU/attach-target-device.mlir new file mode 100644 index 0000000000000..df4553c2c2726 --- /dev/null +++ b/mlir/test/Dialect/XeGPU/attach-target-device.mlir @@ -0,0 +1,54 @@ +// RUN: mlir-opt --xegpu-attach-target-device="device-name=pvc" %s -split-input-file -verify-diagnostics + +// module @valid_dpas attributes {gpu.container_module} { +// gpu.module @valid_dpas attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + +// gpu.func @valid_dpas(%a: memref<24x32xf16>, %b: memref<32x24xf16>) { +// %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.layout> +// %load_a = xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<24x32xf16, #xegpu.layout> +// -> vector<24x32xf16> +// %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf16> -> !xegpu.tensor_desc<32x24xf16, #xegpu.layout> +// %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<32x24xf16, #xegpu.layout> -> vector<32x24xf16> + +// %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout} : vector<24x32xf16>, vector<32x24xf16> -> vector<24x24xf16> +// gpu.return +// } +// } +// } + + +// 
RUN: mlir-opt %s -my-pass | FileCheck %s + +// CHECK: module @valid_dpas +// CHECK-SAME: attributes {dlti.target_system_spec = #dlti.target_system_spec<"GPU" = #dlti.target_device_spec<"name" = "pvc">>, gpu.container_module} +module @valid_dpas attributes {gpu.container_module} { + // CHECK: gpu.module @valid_dpas + gpu.module @valid_dpas attributes {spirv.target_env = #spirv.target_env<#spirv.vce,api = OpenCL,#spirv.resource_limits<>>} { + // CHECK: gpu.func @valid_dpas + gpu.func @valid_dpas(%a: memref<24x32xf16>, %b: memref<32x24xf16>) { + // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG0:.*]]{{\[}}0, 0] + // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16 + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.layout> + + // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] + // CHECK-SAME: -> vector<24x32xf16> + %load_a = xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<24x32xf16, #xegpu.layout> -> vector<24x32xf16> + + // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG1:.*]]{{\[}}0, 0] + // CHECK-SAME: memref<32x24xf16> -> !xegpu.tensor_desc<32x24xf16 + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf16> -> !xegpu.tensor_desc<32x24xf16, #xegpu.layout> + + // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] + // CHECK-SAME: -> vector<32x24xf16> + %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<32x24xf16, #xegpu.layout> -> vector<32x24xf16> + + // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] + // CHECK-SAME: layout_result_0 = #xegpu.layout + // CHECK-SAME: : vector<24x32xf16>, vector<32x24xf16> -> vector<24x24xf16> + %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout} : vector<24x32xf16>, vector<32x24xf16> -> vector<24x24xf16> + + // CHECK: gpu.return + gpu.return + } + } +} From f4e33724d78e65370601e142a5cf34b6e96a1fde Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Wed, 2 Jul 2025 20:28:29 +0000 Subject: [PATCH 11/11] 
Move uArchMap population in XeGPUDialect Initialization phase. This way it's available always. --- mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 9 +++++++++ .../Dialect/XeGPU/Transforms/CMakeLists.txt | 1 + mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt | 1 - mlir/test/Dialect/XeGPU/invalid.mlir | 19 +++++++++++++++++++ 5 files changed, 30 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt index 242a97ccfdf6d..5393b9b7b1c6f 100644 --- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRXeGPUDialect MLIRArithUtils MLIRDialectUtils MLIRIR + MLIRXeGPUuArch MLIRViewLikeInterface MLIRVectorDialect ) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 7ef61de190b4c..c198478a508c5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" @@ -32,6 +33,14 @@ void XeGPUDialect::initialize() { #define GET_ATTRDEF_LIST #include >(); + + // Populate the uArchMap with the supported target devices + auto pvcuArch = + std::make_shared(); + mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch); + auto bmguArch = + std::make_shared(); + mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch); } // Checks if the given shape can be evenly distributed based on the layout diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt index dbde19f11da33..c88c43aa43941 100644 --- 
a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt @@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms MLIRTransforms MLIRGPUDialect MLIRXeGPUUtils + MLIRXeGPUuArch MLIRGPUUtils MLIRVectorTransforms ) diff --git a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt index b880f9abf04ac..c7f691cb6dda7 100644 --- a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt @@ -7,6 +7,5 @@ add_mlir_dialect_library(MLIRXeGPUuArch LINK_LIBS PUBLIC MLIRIR MLIRDialectUtils - MLIRXeGPUDialect ) diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index a2778cd94d963..f07bb8301a88f 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -646,3 +646,22 @@ func.func @tensor_desc_invalid_sg_data(%src: ui64, %offsets: vector<16xindex>) { #xegpu.layout> return } + + +// ----- +module @invalid_dpas attributes {dlti.target_system_spec = #dlti.target_system_spec<"GPU" = #dlti.target_device_spec<"name" = "pvc">>, gpu.container_module} { + gpu.module @invalid_dpas attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + + gpu.func @invalid_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %load_a = xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout> + %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<32x24xf32, #xegpu.layout> -> vector<32x24xf32> + // expected-error@+1 {{Unsupported DPAS types.}} + %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + gpu.return + } + } +} +