Skip to content

Commit fe05765

Browse files
robamlerteoxoy
authored andcommitted
Use intrinsics for dot4{I, U}8Packed in HLSL
1 parent 892f629 commit fe05765

9 files changed

+152
-25
lines changed

naga/src/back/hlsl/writer.rs

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use super::{
1212
WrappedZeroValue,
1313
},
1414
storage::StoreValue,
15-
BackendResult, Error, FragmentEntryPoint, Options,
15+
BackendResult, Error, FragmentEntryPoint, Options, ShaderModel,
1616
};
1717
use crate::{
1818
back::{self, Baked},
@@ -3751,33 +3751,48 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
37513751
fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
37523752
let arg1 = arg1.unwrap();
37533753

3754-
write!(self.out, "dot(")?;
3754+
if self.options.shader_model >= ShaderModel::V6_4 {
3755+
// Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
3756+
let function_name = match fun {
3757+
Function::Dot4I8Packed => "dot4add_i8packed",
3758+
Function::Dot4U8Packed => "dot4add_u8packed",
3759+
_ => unreachable!(),
3760+
};
3761+
write!(self.out, "{function_name}(")?;
3762+
self.write_expr(module, arg, func_ctx)?;
3763+
write!(self.out, ", ")?;
3764+
self.write_expr(module, arg1, func_ctx)?;
3765+
write!(self.out, ", 0)")?;
3766+
} else {
3767+
// Fall back to a polyfill as `dot4add_u8packed` is not available.
3768+
write!(self.out, "dot(")?;
37553769

3756-
if matches!(fun, Function::Dot4U8Packed) {
3757-
write!(self.out, "u")?;
3758-
}
3759-
write!(self.out, "int4(")?;
3760-
self.write_expr(module, arg, func_ctx)?;
3761-
write!(self.out, ", ")?;
3762-
self.write_expr(module, arg, func_ctx)?;
3763-
write!(self.out, " >> 8, ")?;
3764-
self.write_expr(module, arg, func_ctx)?;
3765-
write!(self.out, " >> 16, ")?;
3766-
self.write_expr(module, arg, func_ctx)?;
3767-
write!(self.out, " >> 24) << 24 >> 24, ")?;
3770+
if matches!(fun, Function::Dot4U8Packed) {
3771+
write!(self.out, "u")?;
3772+
}
3773+
write!(self.out, "int4(")?;
3774+
self.write_expr(module, arg, func_ctx)?;
3775+
write!(self.out, ", ")?;
3776+
self.write_expr(module, arg, func_ctx)?;
3777+
write!(self.out, " >> 8, ")?;
3778+
self.write_expr(module, arg, func_ctx)?;
3779+
write!(self.out, " >> 16, ")?;
3780+
self.write_expr(module, arg, func_ctx)?;
3781+
write!(self.out, " >> 24) << 24 >> 24, ")?;
37683782

3769-
if matches!(fun, Function::Dot4U8Packed) {
3770-
write!(self.out, "u")?;
3783+
if matches!(fun, Function::Dot4U8Packed) {
3784+
write!(self.out, "u")?;
3785+
}
3786+
write!(self.out, "int4(")?;
3787+
self.write_expr(module, arg1, func_ctx)?;
3788+
write!(self.out, ", ")?;
3789+
self.write_expr(module, arg1, func_ctx)?;
3790+
write!(self.out, " >> 8, ")?;
3791+
self.write_expr(module, arg1, func_ctx)?;
3792+
write!(self.out, " >> 16, ")?;
3793+
self.write_expr(module, arg1, func_ctx)?;
3794+
write!(self.out, " >> 24) << 24 >> 24)")?;
37713795
}
3772-
write!(self.out, "int4(")?;
3773-
self.write_expr(module, arg1, func_ctx)?;
3774-
write!(self.out, ", ")?;
3775-
self.write_expr(module, arg1, func_ctx)?;
3776-
write!(self.out, " >> 8, ")?;
3777-
self.write_expr(module, arg1, func_ctx)?;
3778-
write!(self.out, " >> 16, ")?;
3779-
self.write_expr(module, arg1, func_ctx)?;
3780-
write!(self.out, " >> 24) << 24 >> 24)")?;
37813796
}
37823797
Function::QuantizeToF16 => {
37833798
write!(self.out, "f16tof32(f32tof16(")?;
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Explicitly turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on HLSL.
2+
3+
targets = "HLSL"
4+
5+
[hlsl]
6+
shader_model = "V6_4"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
fn test_packed_integer_dot_product() -> u32 {
2+
let a_5 = 1u;
3+
let b_5 = 2u;
4+
let c_5: i32 = dot4I8Packed(a_5, b_5);
5+
6+
let a_6 = 3u;
7+
let b_6 = 4u;
8+
let c_6: u32 = dot4U8Packed(a_6, b_6);
9+
10+
// test baking of arguments
11+
let c_7: i32 = dot4I8Packed(5u + c_6, 6u + c_6);
12+
let c_8: u32 = dot4U8Packed(7u + c_6, 8u + c_6);
13+
return c_8;
14+
}
15+
16+
@compute @workgroup_size(1)
17+
fn main() {
18+
let c = test_packed_integer_dot_product();
19+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed` on HLSL.
2+
3+
targets = "HLSL"
4+
5+
[hlsl]
6+
shader_model = "V6_3"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
fn test_packed_integer_dot_product() -> u32 {
2+
let a_5 = 1u;
3+
let b_5 = 2u;
4+
let c_5: i32 = dot4I8Packed(a_5, b_5);
5+
6+
let a_6 = 3u;
7+
let b_6 = 4u;
8+
let c_6: u32 = dot4U8Packed(a_6, b_6);
9+
10+
// test baking of arguments
11+
let c_7: i32 = dot4I8Packed(5u + c_6, 6u + c_6);
12+
let c_8: u32 = dot4U8Packed(7u + c_6, 8u + c_6);
13+
return c_8;
14+
}
15+
16+
@compute @workgroup_size(1)
17+
fn main() {
18+
let c = test_packed_integer_dot_product();
19+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
uint test_packed_integer_dot_product()
2+
{
3+
int c_5_ = dot4add_i8packed(1u, 2u, 0);
4+
uint c_6_ = dot4add_u8packed(3u, 4u, 0);
5+
uint _e7 = (5u + c_6_);
6+
uint _e9 = (6u + c_6_);
7+
int c_7_ = dot4add_i8packed(_e7, _e9, 0);
8+
uint _e12 = (7u + c_6_);
9+
uint _e14 = (8u + c_6_);
10+
uint c_8_ = dot4add_u8packed(_e12, _e14, 0);
11+
return c_8_;
12+
}
13+
14+
[numthreads(1, 1, 1)]
15+
void main()
16+
{
17+
const uint _e0 = test_packed_integer_dot_product();
18+
return;
19+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
(
2+
vertex:[
3+
],
4+
fragment:[
5+
],
6+
compute:[
7+
(
8+
entry_point:"main",
9+
target_profile:"cs_6_4",
10+
),
11+
],
12+
)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
uint test_packed_integer_dot_product()
2+
{
3+
int c_5_ = dot(int4(1u, 1u >> 8, 1u >> 16, 1u >> 24) << 24 >> 24, int4(2u, 2u >> 8, 2u >> 16, 2u >> 24) << 24 >> 24);
4+
uint c_6_ = dot(uint4(3u, 3u >> 8, 3u >> 16, 3u >> 24) << 24 >> 24, uint4(4u, 4u >> 8, 4u >> 16, 4u >> 24) << 24 >> 24);
5+
uint _e7 = (5u + c_6_);
6+
uint _e9 = (6u + c_6_);
7+
int c_7_ = dot(int4(_e7, _e7 >> 8, _e7 >> 16, _e7 >> 24) << 24 >> 24, int4(_e9, _e9 >> 8, _e9 >> 16, _e9 >> 24) << 24 >> 24);
8+
uint _e12 = (7u + c_6_);
9+
uint _e14 = (8u + c_6_);
10+
uint c_8_ = dot(uint4(_e12, _e12 >> 8, _e12 >> 16, _e12 >> 24) << 24 >> 24, uint4(_e14, _e14 >> 8, _e14 >> 16, _e14 >> 24) << 24 >> 24);
11+
return c_8_;
12+
}
13+
14+
[numthreads(1, 1, 1)]
15+
void main()
16+
{
17+
const uint _e0 = test_packed_integer_dot_product();
18+
return;
19+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
(
2+
vertex:[
3+
],
4+
fragment:[
5+
],
6+
compute:[
7+
(
8+
entry_point:"main",
9+
target_profile:"cs_6_3",
10+
),
11+
],
12+
)

0 commit comments

Comments
 (0)