Skip to content

Commit 5b20979

Browse files
robamlerteoxoy
authored and committed
Use intrinsics for dot4{I, U}8Packed on spv
1 parent fe05765 commit 5b20979

File tree

7 files changed

+310
-139
lines changed

7 files changed

+310
-139
lines changed

naga/src/back/spv/block.rs

Lines changed: 77 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,59 +1143,88 @@ impl BlockContext<'_> {
11431143
),
11441144
},
11451145
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1146-
// TODO: consider using packed integer dot product if PackedVectorFormat4x8Bit is available
1147-
let (extract_op, arg0_id, arg1_id) = match fun {
1148-
Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1149-
Mf::Dot4I8Packed => {
1150-
// Convert both packed arguments to signed integers so that we can apply the
1151-
// `BitFieldSExtract` operation on them in `write_dot_product` below.
1152-
let new_arg0_id = self.gen_id();
1153-
block.body.push(Instruction::unary(
1154-
spirv::Op::Bitcast,
1155-
result_type_id,
1156-
new_arg0_id,
1157-
arg0_id,
1158-
));
1146+
if self
1147+
.writer
1148+
.require_all(&[
1149+
spirv::Capability::DotProduct,
1150+
spirv::Capability::DotProductInput4x8BitPacked,
1151+
])
1152+
.is_ok()
1153+
{
1154+
// Write optimized code using `PackedVectorFormat4x8Bit`.
1155+
self.writer.use_extension("SPV_KHR_integer_dot_product");
1156+
1157+
let op = match fun {
1158+
Mf::Dot4I8Packed => spirv::Op::SDot,
1159+
Mf::Dot4U8Packed => spirv::Op::UDot,
1160+
_ => unreachable!(),
1161+
};
11591162

1160-
let new_arg1_id = self.gen_id();
1161-
block.body.push(Instruction::unary(
1162-
spirv::Op::Bitcast,
1163-
result_type_id,
1164-
new_arg1_id,
1165-
arg1_id,
1166-
));
1163+
block.body.push(Instruction::ternary(
1164+
op,
1165+
result_type_id,
1166+
id,
1167+
arg0_id,
1168+
arg1_id,
1169+
spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1170+
));
1171+
} else {
1172+
// Fall back to a polyfill since `PackedVectorFormat4x8Bit` is not available.
1173+
let (extract_op, arg0_id, arg1_id) = match fun {
1174+
Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1175+
Mf::Dot4I8Packed => {
1176+
// Convert both packed arguments to signed integers so that we can apply the
1177+
// `BitFieldSExtract` operation on them in `write_dot_product` below.
1178+
let new_arg0_id = self.gen_id();
1179+
block.body.push(Instruction::unary(
1180+
spirv::Op::Bitcast,
1181+
result_type_id,
1182+
new_arg0_id,
1183+
arg0_id,
1184+
));
11671185

1168-
(spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1169-
}
1170-
_ => unreachable!(),
1171-
};
1186+
let new_arg1_id = self.gen_id();
1187+
block.body.push(Instruction::unary(
1188+
spirv::Op::Bitcast,
1189+
result_type_id,
1190+
new_arg1_id,
1191+
arg1_id,
1192+
));
11721193

1173-
let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1194+
(spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1195+
}
1196+
_ => unreachable!(),
1197+
};
11741198

1175-
const VEC_LENGTH: u8 = 4;
1176-
let bit_shifts: [_; VEC_LENGTH as usize] = core::array::from_fn(|index| {
1177-
self.writer
1178-
.get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1179-
});
1199+
let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1200+
1201+
const VEC_LENGTH: u8 = 4;
1202+
let bit_shifts: [_; VEC_LENGTH as usize] =
1203+
core::array::from_fn(|index| {
1204+
self.writer
1205+
.get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1206+
});
1207+
1208+
self.write_dot_product(
1209+
id,
1210+
result_type_id,
1211+
arg0_id,
1212+
arg1_id,
1213+
VEC_LENGTH as Word,
1214+
block,
1215+
|result_id, composite_id, index| {
1216+
Instruction::ternary(
1217+
extract_op,
1218+
result_type_id,
1219+
result_id,
1220+
composite_id,
1221+
bit_shifts[index as usize],
1222+
eight,
1223+
)
1224+
},
1225+
);
1226+
}
11801227

1181-
self.write_dot_product(
1182-
id,
1183-
result_type_id,
1184-
arg0_id,
1185-
arg1_id,
1186-
VEC_LENGTH as Word,
1187-
block,
1188-
|result_id, composite_id, index| {
1189-
Instruction::ternary(
1190-
extract_op,
1191-
result_type_id,
1192-
result_id,
1193-
composite_id,
1194-
bit_shifts[index as usize],
1195-
eight,
1196-
)
1197-
},
1198-
);
11991228
self.cached[expr_handle] = id;
12001229
return Ok(());
12011230
}

naga/src/back/spv/writer.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,43 @@ impl Writer {
202202
}
203203
}
204204

205+
/// Indicate that the code requires all of the listed capabilities.
206+
///
207+
/// If all entries of `capabilities` appear in the available capabilities
208+
/// specified in the [`Options`] from which this `Writer` was created
209+
/// (including the case where [`Options::capabilities`] is `None`), add
210+
/// them all to this `Writer`'s [`capabilities_used`] table, and return
211+
/// `Ok(())`. If at least one of the listed capabilities is not available,
212+
/// do not add anything to the `capabilities_used` table, and return the
213+
/// first unavailable requested capability, wrapped in `Err()`.
214+
///
215+
/// This method does not return an [`enum@Error`] in case of failure
216+
/// because it may be used in cases where the caller can recover (e.g.,
217+
/// with a polyfill) if the requested capabilities are not available. In
218+
/// this case, it would be unnecessary work to find *all* the unavailable
219+
/// requested capabilities, and to allocate a `Vec` for them, just so we
220+
/// could return an [`Error::MissingCapabilities`].
221+
///
222+
/// [`capabilities_used`]: Writer::capabilities_used
223+
pub(super) fn require_all(
224+
&mut self,
225+
capabilities: &[spirv::Capability],
226+
) -> Result<(), spirv::Capability> {
227+
if let Some(ref available) = self.capabilities_available {
228+
for requested in capabilities {
229+
if !available.contains(requested) {
230+
return Err(*requested);
231+
}
232+
}
233+
}
234+
235+
for requested in capabilities {
236+
self.capabilities_used.insert(*requested);
237+
}
238+
239+
Ok(())
240+
}
241+
205242
/// Indicate that the code uses the given extension.
206243
pub(super) fn use_extension(&mut self, extension: &'static str) {
207244
self.extensions_used.insert(extension);
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
# Explicitly turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on HLSL.
1+
# Explicitly turn on optimizations for `dot4I8Packed` and `dot4U8Packed`
2+
# on SPIRV and HLSL.
23

3-
targets = "HLSL"
4+
targets = "SPIRV | HLSL"
5+
6+
[spv]
7+
capabilities = ["DotProduct", "DotProductInput4x8BitPacked"]
48

59
[hlsl]
610
shader_model = "V6_4"
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1-
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed` on HLSL.
1+
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed`
2+
# on SPIRV and HLSL.
23

3-
targets = "HLSL"
4+
targets = "SPIRV | HLSL"
5+
6+
[spv]
7+
# Provide some unrelated capability because an empty list of capabilities would
8+
# get mapped to `None`, which would then be interpreted as "all capabilities
9+
# are available".
10+
capabilities = ["Matrix"]
411

512
[hlsl]
613
shader_model = "V6_3"
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
; SPIR-V
2+
; Version: 1.1
3+
; Generator: rspirv
4+
; Bound: 30
5+
OpCapability Shader
6+
OpCapability DotProductKHR
7+
OpCapability DotProductInput4x8BitPackedKHR
8+
OpExtension "SPV_KHR_integer_dot_product"
9+
%1 = OpExtInstImport "GLSL.std.450"
10+
OpMemoryModel Logical GLSL450
11+
OpEntryPoint GLCompute %26 "main"
12+
OpExecutionMode %26 LocalSize 1 1 1
13+
%2 = OpTypeVoid
14+
%3 = OpTypeInt 32 0
15+
%6 = OpTypeFunction %3
16+
%7 = OpConstant %3 1
17+
%8 = OpConstant %3 2
18+
%9 = OpConstant %3 3
19+
%10 = OpConstant %3 4
20+
%11 = OpConstant %3 5
21+
%12 = OpConstant %3 6
22+
%13 = OpConstant %3 7
23+
%14 = OpConstant %3 8
24+
%16 = OpTypeInt 32 1
25+
%27 = OpTypeFunction %2
26+
%5 = OpFunction %3 None %6
27+
%4 = OpLabel
28+
OpBranch %15
29+
%15 = OpLabel
30+
%17 = OpSDotKHR %16 %7 %8 PackedVectorFormat4x8BitKHR
31+
%18 = OpUDotKHR %3 %9 %10 PackedVectorFormat4x8BitKHR
32+
%19 = OpIAdd %3 %11 %18
33+
%20 = OpIAdd %3 %12 %18
34+
%21 = OpSDotKHR %16 %19 %20 PackedVectorFormat4x8BitKHR
35+
%22 = OpIAdd %3 %13 %18
36+
%23 = OpIAdd %3 %14 %18
37+
%24 = OpUDotKHR %3 %22 %23 PackedVectorFormat4x8BitKHR
38+
OpReturnValue %24
39+
OpFunctionEnd
40+
%26 = OpFunction %2 None %27
41+
%25 = OpLabel
42+
OpBranch %28
43+
%28 = OpLabel
44+
%29 = OpFunctionCall %3 %5
45+
OpReturn
46+
OpFunctionEnd
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
; SPIR-V
2+
; Version: 1.1
3+
; Generator: rspirv
4+
; Bound: 99
5+
OpCapability Shader
6+
%1 = OpExtInstImport "GLSL.std.450"
7+
OpMemoryModel Logical GLSL450
8+
OpEntryPoint GLCompute %95 "main"
9+
OpExecutionMode %95 LocalSize 1 1 1
10+
%2 = OpTypeVoid
11+
%3 = OpTypeInt 32 0
12+
%6 = OpTypeFunction %3
13+
%7 = OpConstant %3 1
14+
%8 = OpConstant %3 2
15+
%9 = OpConstant %3 3
16+
%10 = OpConstant %3 4
17+
%11 = OpConstant %3 5
18+
%12 = OpConstant %3 6
19+
%13 = OpConstant %3 7
20+
%14 = OpConstant %3 8
21+
%16 = OpTypeInt 32 1
22+
%20 = OpConstant %3 0
23+
%21 = OpConstant %3 16
24+
%22 = OpConstant %3 24
25+
%23 = OpConstantNull %16
26+
%40 = OpConstantNull %3
27+
%96 = OpTypeFunction %2
28+
%5 = OpFunction %3 None %6
29+
%4 = OpLabel
30+
OpBranch %15
31+
%15 = OpLabel
32+
%18 = OpBitcast %16 %7
33+
%19 = OpBitcast %16 %8
34+
%24 = OpBitFieldSExtract %16 %18 %20 %14
35+
%25 = OpBitFieldSExtract %16 %19 %20 %14
36+
%26 = OpIMul %16 %24 %25
37+
%27 = OpIAdd %16 %23 %26
38+
%28 = OpBitFieldSExtract %16 %18 %14 %14
39+
%29 = OpBitFieldSExtract %16 %19 %14 %14
40+
%30 = OpIMul %16 %28 %29
41+
%31 = OpIAdd %16 %27 %30
42+
%32 = OpBitFieldSExtract %16 %18 %21 %14
43+
%33 = OpBitFieldSExtract %16 %19 %21 %14
44+
%34 = OpIMul %16 %32 %33
45+
%35 = OpIAdd %16 %31 %34
46+
%36 = OpBitFieldSExtract %16 %18 %22 %14
47+
%37 = OpBitFieldSExtract %16 %19 %22 %14
48+
%38 = OpIMul %16 %36 %37
49+
%17 = OpIAdd %16 %35 %38
50+
%41 = OpBitFieldUExtract %3 %9 %20 %14
51+
%42 = OpBitFieldUExtract %3 %10 %20 %14
52+
%43 = OpIMul %3 %41 %42
53+
%44 = OpIAdd %3 %40 %43
54+
%45 = OpBitFieldUExtract %3 %9 %14 %14
55+
%46 = OpBitFieldUExtract %3 %10 %14 %14
56+
%47 = OpIMul %3 %45 %46
57+
%48 = OpIAdd %3 %44 %47
58+
%49 = OpBitFieldUExtract %3 %9 %21 %14
59+
%50 = OpBitFieldUExtract %3 %10 %21 %14
60+
%51 = OpIMul %3 %49 %50
61+
%52 = OpIAdd %3 %48 %51
62+
%53 = OpBitFieldUExtract %3 %9 %22 %14
63+
%54 = OpBitFieldUExtract %3 %10 %22 %14
64+
%55 = OpIMul %3 %53 %54
65+
%39 = OpIAdd %3 %52 %55
66+
%56 = OpIAdd %3 %11 %39
67+
%57 = OpIAdd %3 %12 %39
68+
%59 = OpBitcast %16 %56
69+
%60 = OpBitcast %16 %57
70+
%61 = OpBitFieldSExtract %16 %59 %20 %14
71+
%62 = OpBitFieldSExtract %16 %60 %20 %14
72+
%63 = OpIMul %16 %61 %62
73+
%64 = OpIAdd %16 %23 %63
74+
%65 = OpBitFieldSExtract %16 %59 %14 %14
75+
%66 = OpBitFieldSExtract %16 %60 %14 %14
76+
%67 = OpIMul %16 %65 %66
77+
%68 = OpIAdd %16 %64 %67
78+
%69 = OpBitFieldSExtract %16 %59 %21 %14
79+
%70 = OpBitFieldSExtract %16 %60 %21 %14
80+
%71 = OpIMul %16 %69 %70
81+
%72 = OpIAdd %16 %68 %71
82+
%73 = OpBitFieldSExtract %16 %59 %22 %14
83+
%74 = OpBitFieldSExtract %16 %60 %22 %14
84+
%75 = OpIMul %16 %73 %74
85+
%58 = OpIAdd %16 %72 %75
86+
%76 = OpIAdd %3 %13 %39
87+
%77 = OpIAdd %3 %14 %39
88+
%79 = OpBitFieldUExtract %3 %76 %20 %14
89+
%80 = OpBitFieldUExtract %3 %77 %20 %14
90+
%81 = OpIMul %3 %79 %80
91+
%82 = OpIAdd %3 %40 %81
92+
%83 = OpBitFieldUExtract %3 %76 %14 %14
93+
%84 = OpBitFieldUExtract %3 %77 %14 %14
94+
%85 = OpIMul %3 %83 %84
95+
%86 = OpIAdd %3 %82 %85
96+
%87 = OpBitFieldUExtract %3 %76 %21 %14
97+
%88 = OpBitFieldUExtract %3 %77 %21 %14
98+
%89 = OpIMul %3 %87 %88
99+
%90 = OpIAdd %3 %86 %89
100+
%91 = OpBitFieldUExtract %3 %76 %22 %14
101+
%92 = OpBitFieldUExtract %3 %77 %22 %14
102+
%93 = OpIMul %3 %91 %92
103+
%78 = OpIAdd %3 %90 %93
104+
OpReturnValue %78
105+
OpFunctionEnd
106+
%95 = OpFunction %2 None %96
107+
%94 = OpLabel
108+
OpBranch %97
109+
%97 = OpLabel
110+
%98 = OpFunctionCall %3 %5
111+
OpReturn
112+
OpFunctionEnd

0 commit comments

Comments
 (0)