Skip to content

Commit bf016b9

Browse files
[mlir][spirv] Add support for SPV_ARM_tensors (#144667)
This patch introduces a new custom type `!spirv.arm.tensor<>` to the MLIR SPIR-V dialect to represent `OpTypeTensorARM` as defined in the `SPV_ARM_tensors` extension. The type models a shaped tensor with element type and optional shape, and implements the `ShapedType` interface to enable reuse of MLIR's existing shape-aware infrastructure. The type supports serialization to and from SPIR-V binary as `OpTypeTensorARM`, and emits the required capability (`TensorsARM`) and extension (`SPV_ARM_tensors`) declarations automatically. This addition lays the foundation for supporting structured tensor values natively in SPIR-V and will enable future support for operations defined in the `SPV_ARM_tensors` extension, such as `OpTensorReadARM`, `OpTensorWriteARM`, and `OpTensorQuerySizeARM`. Reference: KhronosGroup/SPIRV-Registry#342 --------- Signed-off-by: Davide Grohmann <davide.grohmann@arm.com> Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian@arm.com>
1 parent 1626867 commit bf016b9

File tree

11 files changed

+481
-10
lines changed

11 files changed

+481
-10
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,8 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
422422

423423
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
424424

425+
def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
426+
425427
def SPIRV_ExtensionAttr :
426428
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
427429
SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
@@ -445,6 +447,7 @@ def SPIRV_ExtensionAttr :
445447
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
446448
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
447449
SPV_EXT_mesh_shader,
450+
SPV_ARM_tensors,
448451
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
449452
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
450453
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1311,6 +1314,24 @@ def SPIRV_C_GeometryStreams : I32EnumAttrCase<"Geome
13111314
def SPIRV_C_MultiViewport : I32EnumAttrCase<"MultiViewport", 57> {
13121315
list<I32EnumAttrCase> implies = [SPIRV_C_Geometry];
13131316
}
1317+
def SPIRV_C_TensorsARM : I32EnumAttrCase<"TensorsARM", 4174> {
1318+
list<I32EnumAttrCase> implies = [SPIRV_C_Int8];
1319+
list<Availability> availability = [
1320+
Extension<[SPV_ARM_tensors]>
1321+
];
1322+
}
1323+
def SPIRV_C_StorageTensorArrayDynamicIndexingEXT : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> {
1324+
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader];
1325+
list<Availability> availability = [
1326+
Extension<[SPV_ARM_tensors]>
1327+
];
1328+
}
1329+
def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> {
1330+
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform];
1331+
list<Availability> availability = [
1332+
Extension<[SPV_ARM_tensors]>
1333+
];
1334+
}
13141335
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
13151336
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
13161337
list<Availability> availability = [
@@ -1523,6 +1544,8 @@ def SPIRV_CapabilityAttr :
15231544
SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize,
15241545
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
15251546
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
1547+
SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
1548+
SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
15261549
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
15271550
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
15281551
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
41794202
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
41804203
def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
41814204
def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
4182-
4205+
def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
41834206

41844207
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
41854208
// for the definition of the following types and type categories.
@@ -4217,6 +4240,8 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
42174240
"any SPIR-V struct type">;
42184241
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
42194242
"any SPIR-V sampled image type">;
4243+
def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
4244+
"any SPIR-V tensorArm type">;
42204245

42214246
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
42224247
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
@@ -4228,7 +4253,7 @@ def SPIRV_Type : AnyTypeOf<[
42284253
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
42294254
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
42304255
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
4231-
SPIRV_AnyImage
4256+
SPIRV_AnyImage, SPIRV_AnyTensorArm
42324257
]>;
42334258

42344259
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4525,6 +4550,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
45254550
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
45264551
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
45274552
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
4553+
def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
45284554
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
45294555
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
45304556
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4638,7 +4664,9 @@ def SPIRV_OpcodeAttr :
46384664
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
46394665
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
46404666
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
4641-
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
4667+
SPIRV_OC_OpGroupNonUniformLogicalXor,
4668+
SPIRV_OC_OpTypeTensorARM,
4669+
SPIRV_OC_OpSubgroupBallotKHR,
46424670
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
46434671
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
46444672
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ namespace spirv {
2929
namespace detail {
3030
struct ArrayTypeStorage;
3131
struct CooperativeMatrixTypeStorage;
32+
struct TensorArmTypeStorage;
3233
struct ImageTypeStorage;
3334
struct MatrixTypeStorage;
3435
struct PointerTypeStorage;
@@ -96,7 +97,8 @@ class ScalarType : public SPIRVType {
9697
std::optional<int64_t> getSizeInBytes();
9798
};
9899

99-
// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
100+
// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
101+
// StructType, or SPIR-V TensorArmType.
100102
class CompositeType : public SPIRVType {
101103
public:
102104
using SPIRVType::SPIRVType;
@@ -477,6 +479,46 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
477479
std::optional<StorageClass> storage = std::nullopt);
478480
};
479481

482+
/// SPIR-V TensorARM Type
483+
class TensorArmType
484+
: public Type::TypeBase<TensorArmType, CompositeType,
485+
detail::TensorArmTypeStorage, ShapedType::Trait> {
486+
public:
487+
using Base::Base;
488+
489+
using ShapedTypeTraits = ShapedType::Trait<TensorArmType>;
490+
using ShapedTypeTraits::getDimSize;
491+
using ShapedTypeTraits::getDynamicDimIndex;
492+
using ShapedTypeTraits::getElementTypeBitWidth;
493+
using ShapedTypeTraits::getNumDynamicDims;
494+
using ShapedTypeTraits::getNumElements;
495+
using ShapedTypeTraits::getRank;
496+
using ShapedTypeTraits::hasStaticShape;
497+
using ShapedTypeTraits::isDynamicDim;
498+
499+
static constexpr StringLiteral name = "spirv.arm.tensor";
500+
501+
// TensorArm supports minimum rank of 1, hence an empty shape here means
502+
// unranked.
503+
static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
504+
TensorArmType cloneWith(std::optional<ArrayRef<int64_t>> shape,
505+
Type elementType) const;
506+
507+
static LogicalResult
508+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
509+
ArrayRef<int64_t> shape, Type elementType);
510+
511+
Type getElementType() const;
512+
ArrayRef<int64_t> getShape() const;
513+
bool hasRank() const { return !getShape().empty(); }
514+
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
515+
516+
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
517+
std::optional<StorageClass> storage = std::nullopt);
518+
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
519+
std::optional<StorageClass> storage = std::nullopt);
520+
};
521+
480522
} // namespace spirv
481523
} // namespace mlir
482524

mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
194194
<< t.getNumElements();
195195
return Type();
196196
}
197+
} else if (auto t = dyn_cast<TensorArmType>(type)) {
198+
if (!isa<ScalarType>(t.getElementType())) {
199+
parser.emitError(
200+
typeLoc, "only scalar element type allowed in tensor type but found ")
201+
<< t.getElementType();
202+
return Type();
203+
}
197204
} else {
198205
parser.emitError(typeLoc, "cannot use ")
199206
<< type << " to compose SPIR-V types";
@@ -363,6 +370,52 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
363370
return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
364371
}
365372

373+
// tensor-arm-type ::=
374+
// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
375+
static Type parseTensorArmType(SPIRVDialect const &dialect,
376+
DialectAsmParser &parser) {
377+
if (parser.parseLess())
378+
return {};
379+
380+
bool unranked = false;
381+
SmallVector<int64_t, 4> dims;
382+
SMLoc countLoc = parser.getCurrentLocation();
383+
384+
if (parser.parseOptionalStar().succeeded()) {
385+
unranked = true;
386+
if (parser.parseXInDimensionList())
387+
return {};
388+
} else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
389+
return {};
390+
}
391+
392+
if (!unranked && dims.empty()) {
393+
parser.emitError(countLoc, "arm.tensors do not support rank zero");
394+
return {};
395+
}
396+
397+
if (llvm::is_contained(dims, 0)) {
398+
parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
399+
return {};
400+
}
401+
402+
if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
403+
llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
404+
parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
405+
"fully dynamic or completed shaped");
406+
return {};
407+
}
408+
409+
auto elementTy = parseAndVerifyType(dialect, parser);
410+
if (!elementTy)
411+
return {};
412+
413+
if (parser.parseGreater())
414+
return {};
415+
416+
return TensorArmType::get(dims, elementTy);
417+
}
418+
366419
// TODO: Reorder methods to be utilities first and parse*Type
367420
// methods in alphabetical order
368421
//
@@ -759,6 +812,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
759812
return parseStructType(*this, parser);
760813
if (keyword == "matrix")
761814
return parseMatrixType(*this, parser);
815+
if (keyword == "arm.tensor")
816+
return parseTensorArmType(*this, parser);
762817
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
763818
return Type();
764819
}
@@ -855,10 +910,28 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
855910
os << ">";
856911
}
857912

913+
static void print(TensorArmType type, DialectAsmPrinter &os) {
914+
os << "arm.tensor<";
915+
916+
llvm::interleave(
917+
type.getShape(), os,
918+
[&](int64_t dim) {
919+
if (ShapedType::isDynamic(dim))
920+
os << '?';
921+
else
922+
os << dim;
923+
},
924+
"x");
925+
if (!type.hasRank()) {
926+
os << "*";
927+
}
928+
os << "x" << type.getElementType() << ">";
929+
}
930+
858931
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
859932
TypeSwitch<Type>(type)
860933
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
861-
ImageType, SampledImageType, StructType, MatrixType>(
934+
ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
862935
[&](auto type) { print(type, os); })
863936
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
864937
}

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,12 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
547547
return failure();
548548
}
549549

550+
if (llvm::isa<TensorArmType>(type)) {
551+
if (parser.parseOptionalColon().succeeded())
552+
if (parser.parseType(type))
553+
return failure();
554+
}
555+
550556
return parser.addTypeToList(type, result.types);
551557
}
552558

0 commit comments

Comments
 (0)