Skip to content

Commit 94b15a1

Browse files
authored
[mlir][spirv] Add basic support for SPV_EXT_replicated_composites (#147067)
This patch introduces two new ops to the SPIR-V dialect: - `spirv.EXT.ConstantCompositeReplicate` - `spirv.EXT.SpecConstantCompositeReplicate` These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to `SPV_EXT_replicated_composites` extension instructions: - `OpConstantCompositeReplicatedEXT` - `OpSpecConstantCompositeReplicatedEXT` No transformation to these new ops has been introduced in this patch. This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987 --------- Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian@arm.com>
1 parent 612afab commit 94b15a1

File tree

13 files changed

+694
-9
lines changed

13 files changed

+694
-9
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ def SPV_EXT_shader_atomic_float_min_max : I32EnumAttrCase<"SPV_EXT_shader_atomi
359359
def SPV_EXT_shader_image_int64 : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
360360
def SPV_EXT_shader_atomic_float16_add : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
361361
def SPV_EXT_mesh_shader : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
362+
def SPV_EXT_replicated_composites : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;
362363

363364
def SPV_AMD_gpu_shader_half_float_fetch : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
364365
def SPV_AMD_shader_ballot : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -446,7 +447,7 @@ def SPIRV_ExtensionAttr :
446447
SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
447448
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
448449
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
449-
SPV_EXT_mesh_shader,
450+
SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
450451
SPV_ARM_tensors,
451452
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
452453
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
@@ -849,6 +850,12 @@ def SPIRV_C_CooperativeMatrixKHR : I32EnumAttrCase<"Coope
849850
MinVersion<SPIRV_V_1_6>
850851
];
851852
}
853+
def SPIRV_C_ReplicatedCompositesEXT : I32EnumAttrCase<"ReplicatedCompositesEXT", 6024> {
854+
list<Availability> availability = [
855+
Extension<[SPV_EXT_replicated_composites]>,
856+
MinVersion<SPIRV_V_1_0>
857+
];
858+
}
852859
def SPIRV_C_BitInstructions : I32EnumAttrCase<"BitInstructions", 6025> {
853860
list<Availability> availability = [
854861
Extension<[SPV_KHR_bit_instructions]>
@@ -1500,7 +1507,7 @@ def SPIRV_CapabilityAttr :
15001507
SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
15011508
SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
15021509
SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
1503-
SPIRV_C_CooperativeMatrixKHR,
1510+
SPIRV_C_CooperativeMatrixKHR, SPIRV_C_ReplicatedCompositesEXT,
15041511
SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
15051512
SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
15061513
SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
@@ -4565,6 +4572,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMa
45654572
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
45664573
def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
45674574
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
4575+
def SPIRV_OC_OpConstantCompositeReplicateEXT : I32EnumAttrCase<"OpConstantCompositeReplicateEXT", 4461>;
4576+
def SPIRV_OC_OpSpecConstantCompositeReplicateEXT : I32EnumAttrCase<"OpSpecConstantCompositeReplicateEXT", 4462>;
45684577
def SPIRV_OC_OpEmitMeshTasksEXT : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
45694578
def SPIRV_OC_OpSetMeshOutputsEXT : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
45704579
def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
@@ -4674,6 +4683,8 @@ def SPIRV_OpcodeAttr :
46744683
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
46754684
SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
46764685
SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
4686+
SPIRV_OC_OpConstantCompositeReplicateEXT,
4687+
SPIRV_OC_OpSpecConstantCompositeReplicateEXT,
46774688
SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
46784689
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
46794690
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,47 @@ def SPIRV_ConstantOp : SPIRV_Op<"Constant",
135135
let autogenSerialization = 0;
136136
}
137137

138+
139+
// -----
140+
141+
def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantCompositeReplicate", [Pure]> {
142+
let summary = [{
143+
Declare a new replicated composite constant op.
144+
}];
145+
146+
let description = [{
147+
Represents a splat composite constant i.e., all elements of composite constant
148+
have the same value.
149+
150+
#### Example:
151+
152+
```mlir
153+
%0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
154+
%1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<2 x vector<2xi32>>
155+
%2 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
156+
```
157+
}];
158+
159+
let availability = [
160+
MinVersion<SPIRV_V_1_0>,
161+
MaxVersion<SPIRV_V_1_6>,
162+
Extension<[SPV_EXT_replicated_composites]>,
163+
Capability<[SPIRV_C_ReplicatedCompositesEXT]>
164+
];
165+
166+
let arguments = (ins
167+
AnyAttr:$value
168+
);
169+
170+
let results = (outs
171+
SPIRV_Composite:$replicated_constant
172+
);
173+
174+
let autogenSerialization = 0;
175+
176+
let assemblyFormat = "` ` `[` $value `]` `:` type($replicated_constant) attr-dict";
177+
}
178+
138179
// -----
139180

140181
def SPIRV_EntryPointOp : SPIRV_Op<"EntryPoint", [InModuleScope]> {
@@ -689,6 +730,43 @@ def SPIRV_SpecConstantCompositeOp : SPIRV_Op<"SpecConstantComposite", [
689730

690731
// -----
691732

733+
def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantCompositeReplicate", [InModuleScope, Symbol]> {
734+
let summary = "Declare a new replicated composite specialization constant op.";
735+
736+
let description = [{
737+
Represents a splat spec composite constant i.e., all elements of spec composite
738+
constant have the same value. The splat value must come from a symbol reference
739+
of spec constant instruction.
740+
741+
#### Example:
742+
743+
```mlir
744+
spirv.SpecConstant @sc_i32_1 = 1 : i32
745+
spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
746+
spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
747+
```
748+
}];
749+
750+
let availability = [
751+
MinVersion<SPIRV_V_1_0>,
752+
MaxVersion<SPIRV_V_1_6>,
753+
Extension<[SPV_EXT_replicated_composites]>,
754+
Capability<[SPIRV_C_ReplicatedCompositesEXT]>
755+
];
756+
757+
let arguments = (ins
758+
TypeAttr:$type,
759+
StrAttr:$sym_name,
760+
SymbolRefAttr:$constituent
761+
);
762+
763+
let results = (outs);
764+
765+
let autogenSerialization = 0;
766+
}
767+
768+
// -----
769+
692770
def SPIRV_SpecConstantOperationOp : SPIRV_Op<"SpecConstantOperation", [
693771
Pure, InFunctionScope,
694772
SingleBlockImplicitTerminator<"YieldOp">]> {

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,44 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
763763
setNameFn(getResult(), specialName.str());
764764
}
765765

766+
//===----------------------------------------------------------------------===//
767+
// spirv.EXTConstantCompositeReplicate
768+
//===----------------------------------------------------------------------===//
769+
770+
LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
771+
Type valueType;
772+
if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) {
773+
valueType = typedAttr.getType();
774+
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
775+
auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
776+
if (!typedElemAttr)
777+
return emitError("value attribute is not typed");
778+
valueType =
779+
spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
780+
} else {
781+
return emitError("unknown value attribute type");
782+
}
783+
784+
auto compositeType = dyn_cast<spirv::CompositeType>(getType());
785+
if (!compositeType)
786+
return emitError("result type is not a composite type");
787+
788+
Type compositeElementType = compositeType.getElementType(0);
789+
790+
SmallVector<Type, 3> possibleTypes = {compositeElementType};
791+
while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
792+
compositeElementType = type.getElementType(0);
793+
possibleTypes.push_back(compositeElementType);
794+
}
795+
796+
if (!is_contained(possibleTypes, valueType)) {
797+
return emitError("expected value attribute type ")
798+
<< interleaved(possibleTypes, " or ") << ", but got: " << valueType;
799+
}
800+
801+
return success();
802+
}
803+
766804
//===----------------------------------------------------------------------===//
767805
// spirv.ControlBarrierOp
768806
//===----------------------------------------------------------------------===//
@@ -1864,6 +1902,69 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
18641902
return success();
18651903
}
18661904

1905+
//===----------------------------------------------------------------------===//
1906+
// spirv.EXTSpecConstantCompositeReplicateOp
1907+
//===----------------------------------------------------------------------===//
1908+
1909+
ParseResult
1910+
spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
1911+
OperationState &result) {
1912+
StringAttr compositeName;
1913+
FlatSymbolRefAttr specConstRef;
1914+
const char *attrName = "spec_const";
1915+
NamedAttrList attrs;
1916+
Type type;
1917+
1918+
if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1919+
result.attributes) ||
1920+
parser.parseLParen() ||
1921+
parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
1922+
parser.parseRParen() || parser.parseColonType(type))
1923+
return failure();
1924+
1925+
StringAttr compositeSpecConstituentName =
1926+
spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1927+
result.name);
1928+
result.addAttribute(compositeSpecConstituentName, specConstRef);
1929+
1930+
StringAttr typeAttrName =
1931+
spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
1932+
result.addAttribute(typeAttrName, TypeAttr::get(type));
1933+
1934+
return success();
1935+
}
1936+
1937+
void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
1938+
printer << " ";
1939+
printer.printSymbolName(getSymName());
1940+
printer << " (" << this->getConstituent() << ") : " << getType();
1941+
}
1942+
1943+
LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1944+
auto compositeType = dyn_cast<spirv::CompositeType>(getType());
1945+
if (!compositeType)
1946+
return emitError("result type must be a composite type, but provided ")
1947+
<< getType();
1948+
1949+
Operation *constituentOp = SymbolTable::lookupNearestSymbolFrom(
1950+
(*this)->getParentOp(), this->getConstituent());
1951+
if (!constituentOp)
1952+
return emitError(
1953+
"splat spec constant reference defining constituent not found");
1954+
1955+
auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1956+
if (!constituentSpecConstOp)
1957+
return emitError("constituent is not a spec constant");
1958+
1959+
Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1960+
Type compositeElementType = compositeType.getElementType(0);
1961+
if (constituentType != compositeElementType)
1962+
return emitError("constituent has incorrect type: expected ")
1963+
<< compositeElementType << ", but provided " << constituentType;
1964+
1965+
return success();
1966+
}
1967+
18671968
//===----------------------------------------------------------------------===//
18681969
// spirv.SpecConstantOperation
18691970
//===----------------------------------------------------------------------===//

mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
4545
return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
4646
constInfo->first);
4747
}
48+
if (std::optional<std::pair<Attribute, Type>> constCompositeReplicateInfo =
49+
getConstantCompositeReplicate(id)) {
50+
return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
51+
unknownLoc, constCompositeReplicateInfo->second,
52+
constCompositeReplicateInfo->first);
53+
}
4854
if (auto varOp = getGlobalVariable(id)) {
4955
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
5056
unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
@@ -56,10 +62,18 @@ Value spirv::Deserializer::getValue(uint32_t id) {
5662
SymbolRefAttr::get(constOp.getOperation()));
5763
return referenceOfOp.getReference();
5864
}
59-
if (auto constCompositeOp = getSpecConstantComposite(id)) {
65+
if (SpecConstantCompositeOp specConstCompositeOp =
66+
getSpecConstantComposite(id)) {
67+
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
68+
unknownLoc, specConstCompositeOp.getType(),
69+
SymbolRefAttr::get(specConstCompositeOp.getOperation()));
70+
return referenceOfOp.getReference();
71+
}
72+
if (auto specConstCompositeReplicateOp =
73+
getSpecConstantCompositeReplicate(id)) {
6074
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
61-
unknownLoc, constCompositeOp.getType(),
62-
SymbolRefAttr::get(constCompositeOp.getOperation()));
75+
unknownLoc, specConstCompositeReplicateOp.getType(),
76+
SymbolRefAttr::get(specConstCompositeReplicateOp.getOperation()));
6377
return referenceOfOp.getReference();
6478
}
6579
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
@@ -175,8 +189,12 @@ LogicalResult spirv::Deserializer::processInstruction(
175189
return processConstant(operands, /*isSpec=*/true);
176190
case spirv::Opcode::OpConstantComposite:
177191
return processConstantComposite(operands);
192+
case spirv::Opcode::OpConstantCompositeReplicateEXT:
193+
return processConstantCompositeReplicateEXT(operands);
178194
case spirv::Opcode::OpSpecConstantComposite:
179195
return processSpecConstantComposite(operands);
196+
case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
197+
return processSpecConstantCompositeReplicateEXT(operands);
180198
case spirv::Opcode::OpSpecConstantOp:
181199
return processSpecConstantOperation(operands);
182200
case spirv::Opcode::OpConstantTrue:

0 commit comments

Comments
 (0)