Skip to content

Commit 3402f02

Browse files
Firestar99eddyb
authored andcommitted
update rspirv: add new ops and OpKinds
1 parent edc8770 commit 3402f02

File tree

2 files changed

+264
-4
lines changed

2 files changed

+264
-4
lines changed

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@ use super::Builder;
22
use crate::builder_spirv::{BuilderCursor, SpirvValue};
33
use crate::codegen_cx::CodegenCx;
44
use crate::spirv_type::SpirvType;
5-
use num_traits::FromPrimitive;
65
use rspirv::dr;
76
use rspirv::grammar::{reflect, LogicalOperand, OperandKind, OperandQuantifier};
87
use rspirv::spirv::{
9-
FPFastMathMode, FragmentShadingRate, FunctionControl, GroupOperation, ImageOperands,
10-
KernelProfilingInfo, LoopControl, MemoryAccess, MemorySemantics, Op, RayFlags,
11-
SelectionControl, StorageClass, Word,
8+
CooperativeMatrixOperands, FPFastMathMode, FragmentShadingRate, FunctionControl,
9+
GroupOperation, ImageOperands, KernelProfilingInfo, LoopControl, MemoryAccess, MemorySemantics,
10+
Op, RayFlags, SelectionControl, StorageClass, Word,
1211
};
1312
use rustc_ast::ast::{InlineAsmOptions, InlineAsmTemplatePiece};
1413
use rustc_codegen_ssa::mir::place::PlaceRef;
@@ -1386,6 +1385,61 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
13861385
.push(dr::Operand::RayQueryCandidateIntersectionType(x)),
13871386
Err(()) => self.err(format!("unknown RayQueryCandidateIntersectionType {word}")),
13881387
},
1388+
(OperandKind::FPDenormMode, Some(word)) => match word.parse() {
1389+
Ok(x) => inst.operands.push(dr::Operand::FPDenormMode(x)),
1390+
Err(()) => self.err(format!("unknown FPDenormMode {word}")),
1391+
},
1392+
(OperandKind::QuantizationModes, Some(word)) => match word.parse() {
1393+
Ok(x) => inst.operands.push(dr::Operand::QuantizationModes(x)),
1394+
Err(()) => self.err(format!("unknown QuantizationModes {word}")),
1395+
},
1396+
(OperandKind::FPOperationMode, Some(word)) => match word.parse() {
1397+
Ok(x) => inst.operands.push(dr::Operand::FPOperationMode(x)),
1398+
Err(()) => self.err(format!("unknown FPOperationMode {word}")),
1399+
},
1400+
(OperandKind::OverflowModes, Some(word)) => match word.parse() {
1401+
Ok(x) => inst.operands.push(dr::Operand::OverflowModes(x)),
1402+
Err(()) => self.err(format!("unknown OverflowModes {word}")),
1403+
},
1404+
(OperandKind::PackedVectorFormat, Some(word)) => match word.parse() {
1405+
Ok(x) => inst.operands.push(dr::Operand::PackedVectorFormat(x)),
1406+
Err(()) => self.err(format!("unknown PackedVectorFormat {word}")),
1407+
},
1408+
(OperandKind::HostAccessQualifier, Some(word)) => match word.parse() {
1409+
Ok(x) => inst.operands.push(dr::Operand::HostAccessQualifier(x)),
1410+
Err(()) => self.err(format!("unknown HostAccessQualifier {word}")),
1411+
},
1412+
(OperandKind::CooperativeMatrixOperands, Some(word)) => {
1413+
match parse_bitflags_operand(COOPERATIVE_MATRIX_OPERANDS, word) {
1414+
Some(x) => inst
1415+
.operands
1416+
.push(dr::Operand::CooperativeMatrixOperands(x)),
1417+
None => self.err(format!("Unknown CooperativeMatrixOperands {word}")),
1418+
}
1419+
}
1420+
(OperandKind::CooperativeMatrixLayout, Some(word)) => match word.parse() {
1421+
Ok(x) => inst.operands.push(dr::Operand::CooperativeMatrixLayout(x)),
1422+
Err(()) => self.err(format!("unknown CooperativeMatrixLayout {word}")),
1423+
},
1424+
(OperandKind::CooperativeMatrixUse, Some(word)) => match word.parse() {
1425+
Ok(x) => inst.operands.push(dr::Operand::CooperativeMatrixUse(x)),
1426+
Err(()) => self.err(format!("unknown CooperativeMatrixUse {word}")),
1427+
},
1428+
(OperandKind::InitializationModeQualifier, Some(word)) => match word.parse() {
1429+
Ok(x) => inst
1430+
.operands
1431+
.push(dr::Operand::InitializationModeQualifier(x)),
1432+
Err(()) => self.err(format!("unknown InitializationModeQualifier {word}")),
1433+
},
1434+
(OperandKind::LoadCacheControl, Some(word)) => match word.parse() {
1435+
Ok(x) => inst.operands.push(dr::Operand::LoadCacheControl(x)),
1436+
Err(()) => self.err(format!("unknown LoadCacheControl {word}")),
1437+
},
1438+
(OperandKind::StoreCacheControl, Some(word)) => match word.parse() {
1439+
Ok(x) => inst.operands.push(dr::Operand::StoreCacheControl(x)),
1440+
Err(()) => self.err(format!("unknown StoreCacheControl {word}")),
1441+
},
1442+
(OperandKind::LiteralFloat, Some(word)) => todo!(),
13891443
(kind, None) => match token {
13901444
Token::Word(_) => bug!(),
13911445
Token::String(_) => {
@@ -1557,6 +1611,29 @@ pub const FRAGMENT_SHADING_RATE: &[(&str, FragmentShadingRate)] = &[
15571611
FragmentShadingRate::HORIZONTAL4_PIXELS,
15581612
),
15591613
];
1614+
pub const COOPERATIVE_MATRIX_OPERANDS: &[(&str, CooperativeMatrixOperands)] = &[
1615+
("NONE_KHR", CooperativeMatrixOperands::NONE_KHR),
1616+
(
1617+
"MATRIX_A_SIGNED_COMPONENTS_KHR",
1618+
CooperativeMatrixOperands::MATRIX_A_SIGNED_COMPONENTS_KHR,
1619+
),
1620+
(
1621+
"MATRIX_B_SIGNED_COMPONENTS_KHR",
1622+
CooperativeMatrixOperands::MATRIX_B_SIGNED_COMPONENTS_KHR,
1623+
),
1624+
(
1625+
"MATRIX_C_SIGNED_COMPONENTS_KHR",
1626+
CooperativeMatrixOperands::MATRIX_C_SIGNED_COMPONENTS_KHR,
1627+
),
1628+
(
1629+
"MATRIX_RESULT_SIGNED_COMPONENTS_KHR",
1630+
CooperativeMatrixOperands::MATRIX_RESULT_SIGNED_COMPONENTS_KHR,
1631+
),
1632+
(
1633+
"SATURATING_ACCUMULATION_KHR",
1634+
CooperativeMatrixOperands::SATURATING_ACCUMULATION_KHR,
1635+
),
1636+
];
15601637

15611638
fn parse_bitflags_operand<T: std::ops::BitOr<Output = T> + Copy>(
15621639
values: &'static [(&'static str, T)],

crates/rustc_codegen_spirv/src/spirv_type_constraints.rs

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,12 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> {
494494
Op::IAddCarry | Op::ISubBorrow | Op::UMulExtended | Op::SMulExtended => sig! {
495495
(T, T) -> Struct([T, T])
496496
},
497+
Op::SDot | Op::UDot | Op::SUDot | Op::SDotAccSat | Op::UDotAccSat | Op::SUDotAccSat => {
498+
sig! {
499+
// FIXME(eddyb) missing equality constraint between two vectors
500+
(Vector(T), T) -> Vector(T)
501+
}
502+
}
497503

498504
// 3.37.14. Bit Instructions
499505
Op::ShiftRightLogical
@@ -593,6 +599,8 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> {
593599
},
594600
// Capability: Kernel
595601
Op::AtomicFlagTestAndSet | Op::AtomicFlagClear => {}
602+
// SPV_EXT_shader_atomic_float_min_max
603+
Op::AtomicFMinEXT | Op::AtomicFMaxEXT => sig! { (Pointer(_, T), _, _, T) -> T },
596604
// SPV_EXT_shader_atomic_float_add
597605
Op::AtomicFAddEXT => sig! { (Pointer(_, T), _, _, T) -> T },
598606

@@ -945,6 +953,181 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> {
945953
| Op::SubgroupAvcSicGetInterRawSadsINTEL => {
946954
reserved!(SPV_INTEL_device_side_avc_motion_estimation);
947955
}
956+
// SPV_EXT_mesh_shader
957+
Op::EmitMeshTasksEXT | Op::SetMeshOutputsEXT => {
958+
reserved!(SPV_EXT_mesh_shader)
959+
}
960+
// SPV_NV_ray_tracing_motion_blur
961+
Op::TraceMotionNV | Op::TraceRayMotionNV => reserved!(SPV_NV_ray_tracing_motion_blur),
962+
// SPV_NV_bindless_texture
963+
Op::ConvertUToImageNV
964+
| Op::ConvertUToSamplerNV
965+
| Op::ConvertImageToUNV
966+
| Op::ConvertSamplerToUNV
967+
| Op::ConvertUToSampledImageNV
968+
| Op::ConvertSampledImageToUNV
969+
| Op::SamplerImageAddressingModeNV => reserved!(SPV_NV_bindless_texture),
970+
// SPV_INTEL_inline_assembly
971+
Op::AsmTargetINTEL | Op::AsmINTEL | Op::AsmCallINTEL => reserved!(SPV_NV_bindless_texture),
972+
// SPV_INTEL_variable_length_array
973+
Op::VariableLengthArrayINTEL => reserved!(SPV_INTEL_variable_length_array),
974+
// SPV_KHR_uniform_group_instructions
975+
Op::GroupIMulKHR
976+
| Op::GroupFMulKHR
977+
| Op::GroupBitwiseAndKHR
978+
| Op::GroupBitwiseOrKHR
979+
| Op::GroupBitwiseXorKHR
980+
| Op::GroupLogicalAndKHR
981+
| Op::GroupLogicalOrKHR
982+
| Op::GroupLogicalXorKHR => reserved!(SPV_KHR_uniform_group_instructions),
983+
// SPV_KHR_expect_assume
984+
Op::AssumeTrueKHR | Op::ExpectKHR => reserved!(SPV_KHR_expect_assume),
985+
// SPV_KHR_subgroup_rotate
986+
Op::GroupNonUniformRotateKHR => reserved!(SPV_KHR_subgroup_rotate),
987+
// SPV_NV_shader_invocation_reorder
988+
Op::HitObjectRecordHitMotionNV
989+
| Op::HitObjectRecordHitWithIndexMotionNV
990+
| Op::HitObjectRecordMissMotionNV
991+
| Op::HitObjectGetWorldToObjectNV
992+
| Op::HitObjectGetObjectToWorldNV
993+
| Op::HitObjectGetObjectRayDirectionNV
994+
| Op::HitObjectGetObjectRayOriginNV
995+
| Op::HitObjectTraceRayMotionNV
996+
| Op::HitObjectGetShaderRecordBufferHandleNV
997+
| Op::HitObjectGetShaderBindingTableRecordIndexNV
998+
| Op::HitObjectRecordEmptyNV
999+
| Op::HitObjectTraceRayNV
1000+
| Op::HitObjectRecordHitNV
1001+
| Op::HitObjectRecordHitWithIndexNV
1002+
| Op::HitObjectRecordMissNV
1003+
| Op::HitObjectExecuteShaderNV
1004+
| Op::HitObjectGetCurrentTimeNV
1005+
| Op::HitObjectGetAttributesNV
1006+
| Op::HitObjectGetHitKindNV
1007+
| Op::HitObjectGetPrimitiveIndexNV
1008+
| Op::HitObjectGetGeometryIndexNV
1009+
| Op::HitObjectGetInstanceIdNV
1010+
| Op::HitObjectGetInstanceCustomIndexNV
1011+
| Op::HitObjectGetWorldRayDirectionNV
1012+
| Op::HitObjectGetWorldRayOriginNV
1013+
| Op::HitObjectGetRayTMaxNV
1014+
| Op::HitObjectGetRayTMinNV
1015+
| Op::HitObjectIsEmptyNV
1016+
| Op::HitObjectIsHitNV
1017+
| Op::HitObjectIsMissNV
1018+
| Op::ReorderThreadWithHitObjectNV
1019+
| Op::ReorderThreadWithHintNV
1020+
| Op::TypeHitObjectNV => reserved!(SPV_NV_shader_invocation_reorder),
1021+
// SPV_INTEL_arbitrary_precision_floating_point
1022+
Op::ArbitraryFloatAddINTEL
1023+
| Op::ArbitraryFloatSubINTEL
1024+
| Op::ArbitraryFloatMulINTEL
1025+
| Op::ArbitraryFloatDivINTEL
1026+
| Op::ArbitraryFloatGTINTEL
1027+
| Op::ArbitraryFloatGEINTEL
1028+
| Op::ArbitraryFloatLTINTEL
1029+
| Op::ArbitraryFloatLEINTEL
1030+
| Op::ArbitraryFloatEQINTEL
1031+
| Op::ArbitraryFloatRecipINTEL
1032+
| Op::ArbitraryFloatRSqrtINTEL
1033+
| Op::ArbitraryFloatCbrtINTEL
1034+
| Op::ArbitraryFloatHypotINTEL
1035+
| Op::ArbitraryFloatSqrtINTEL
1036+
| Op::ArbitraryFloatLogINTEL
1037+
| Op::ArbitraryFloatLog2INTEL
1038+
| Op::ArbitraryFloatLog10INTEL
1039+
| Op::ArbitraryFloatLog1pINTEL
1040+
| Op::ArbitraryFloatExpINTEL
1041+
| Op::ArbitraryFloatExp2INTEL
1042+
| Op::ArbitraryFloatExp10INTEL
1043+
| Op::ArbitraryFloatExpm1INTEL
1044+
| Op::ArbitraryFloatSinINTEL
1045+
| Op::ArbitraryFloatCosINTEL
1046+
| Op::ArbitraryFloatSinCosINTEL
1047+
| Op::ArbitraryFloatSinPiINTEL
1048+
| Op::ArbitraryFloatCosPiINTEL
1049+
| Op::ArbitraryFloatSinCosPiINTEL
1050+
| Op::ArbitraryFloatASinINTEL
1051+
| Op::ArbitraryFloatASinPiINTEL
1052+
| Op::ArbitraryFloatACosINTEL
1053+
| Op::ArbitraryFloatACosPiINTEL
1054+
| Op::ArbitraryFloatATanINTEL
1055+
| Op::ArbitraryFloatATanPiINTEL
1056+
| Op::ArbitraryFloatATan2INTEL
1057+
| Op::ArbitraryFloatPowINTEL
1058+
| Op::ArbitraryFloatPowRINTEL
1059+
| Op::ArbitraryFloatPowNINTEL => {
1060+
reserved!(SPV_INTEL_arbitrary_precision_floating_point)
1061+
}
1062+
// TODO these instructions are outdated, and will be replaced by the ones in comments below. When updating, consider merging with the branch above.
1063+
Op::ArbitraryFloatCastINTEL
1064+
| Op::ArbitraryFloatCastFromIntINTEL
1065+
| Op::ArbitraryFloatCastToIntINTEL => {
1066+
// Op::ArbitraryFloatConvertINTEL
1067+
// | Op::ArbitraryFloatConvertFromUIntINTEL
1068+
// | Op::ArbitraryFloatConvertFromSIntINTEL
1069+
// | Op::ArbitraryFloatConvertToUIntINTEL
1070+
// | Op::ArbitraryFloatConvertToSIntINTEL
1071+
reserved!(SPV_INTEL_arbitrary_precision_floating_point)
1072+
}
1073+
// SPV_INTEL_arbitrary_precision_fixed_point
1074+
Op::FixedSqrtINTEL
1075+
| Op::FixedRecipINTEL
1076+
| Op::FixedRsqrtINTEL
1077+
| Op::FixedSinINTEL
1078+
| Op::FixedCosINTEL
1079+
| Op::FixedSinCosINTEL
1080+
| Op::FixedSinPiINTEL
1081+
| Op::FixedCosPiINTEL
1082+
| Op::FixedSinCosPiINTEL
1083+
| Op::FixedLogINTEL
1084+
| Op::FixedExpINTEL => reserved!(SPV_INTEL_arbitrary_precision_fixed_point),
1085+
// SPV_EXT_shader_tile_image
1086+
Op::ColorAttachmentReadEXT | Op::DepthAttachmentReadEXT | Op::StencilAttachmentReadEXT => {
1087+
reserved!(SPV_EXT_shader_tile_image)
1088+
}
1089+
// SPV_KHR_cooperative_matrix
1090+
Op::TypeCooperativeMatrixKHR
1091+
| Op::CooperativeMatrixLoadKHR
1092+
| Op::CooperativeMatrixStoreKHR
1093+
| Op::CooperativeMatrixMulAddKHR
1094+
| Op::CooperativeMatrixLengthKHR => reserved!(SPV_KHR_cooperative_matrix),
1095+
// SPV_QCOM_image_processing
1096+
Op::ImageSampleWeightedQCOM
1097+
| Op::ImageBoxFilterQCOM
1098+
| Op::ImageBlockMatchSSDQCOM
1099+
| Op::ImageBlockMatchSADQCOM => reserved!(SPV_QCOM_image_processing),
1100+
// SPV_AMDX_shader_enqueue
1101+
Op::FinalizeNodePayloadsAMDX
1102+
| Op::FinishWritingNodePayloadAMDX
1103+
| Op::InitializeNodePayloadsAMDX => reserved!(SPV_AMDX_shader_enqueue),
1104+
// SPV_NV_displacement_micromap
1105+
Op::FetchMicroTriangleVertexPositionNV | Op::FetchMicroTriangleVertexBarycentricNV => {
1106+
reserved!(SPV_NV_displacement_micromap)
1107+
}
1108+
// SPV_KHR_ray_tracing_position_fetch
1109+
Op::RayQueryGetIntersectionTriangleVertexPositionsKHR => {
1110+
reserved!(SPV_KHR_ray_tracing_position_fetch)
1111+
}
1112+
// SPV_INTEL_bfloat16_conversion
1113+
Op::ConvertFToBF16INTEL | Op::ConvertBF16ToFINTEL => {
1114+
reserved!(SPV_INTEL_bfloat16_conversion)
1115+
}
1116+
1117+
// TODO unknown_extension_INTEL
1118+
Op::SaveMemoryINTEL
1119+
| Op::RestoreMemoryINTEL
1120+
| Op::AliasDomainDeclINTEL
1121+
| Op::AliasScopeDeclINTEL
1122+
| Op::AliasScopeListDeclINTEL
1123+
| Op::PtrCastToCrossWorkgroupINTEL
1124+
| Op::CrossWorkgroupCastToPtrINTEL
1125+
| Op::TypeBufferSurfaceINTEL
1126+
| Op::TypeStructContinuedINTEL
1127+
| Op::ConstantCompositeContinuedINTEL
1128+
| Op::SpecConstantCompositeContinuedINTEL
1129+
| Op::ControlBarrierArriveINTEL
1130+
| Op::ControlBarrierWaitINTEL => reserved!(unknown_extension_INTEL),
9481131
}
9491132

9501133
None

0 commit comments

Comments
 (0)