Skip to content

Commit 92dc64b

Browse files
authored
Make byte addressable buffer take &self, add support for matrix (#749)
1 parent 95da898 commit 92dc64b

File tree

8 files changed

+50
-33
lines changed

8 files changed

+50
-33
lines changed

crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
5151
}
5252

5353
#[allow(clippy::too_many_arguments)]
54-
fn load_vec_or_arr(
54+
fn load_vec_mat_arr(
5555
&mut self,
5656
original_type: Word,
5757
result_type: Word,
@@ -104,21 +104,22 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
104104
let val = self.load_u32(array, dynamic_word_index, constant_word_offset);
105105
self.bitcast(val, result_type)
106106
}
107-
SpirvType::Vector { element, count } => self.load_vec_or_arr(
108-
original_type,
109-
result_type,
110-
array,
111-
dynamic_word_index,
112-
constant_word_offset,
113-
element,
114-
count,
115-
),
107+
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
108+
.load_vec_mat_arr(
109+
original_type,
110+
result_type,
111+
array,
112+
dynamic_word_index,
113+
constant_word_offset,
114+
element,
115+
count,
116+
),
116117
SpirvType::Array { element, count } => {
117118
let count = match self.builder.lookup_const_u64(count) {
118119
Some(count) => count as u32,
119120
None => return self.load_err(original_type, result_type),
120121
};
121-
self.load_vec_or_arr(
122+
self.load_vec_mat_arr(
122123
original_type,
123124
result_type,
124125
array,
@@ -229,7 +230,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
229230
}
230231

231232
#[allow(clippy::too_many_arguments)]
232-
fn store_vec_or_arr(
233+
fn store_vec_mat_arr(
233234
&mut self,
234235
original_type: Word,
235236
value: SpirvValue,
@@ -278,21 +279,22 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
278279
let value_u32 = self.bitcast(value, u32_ty);
279280
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32);
280281
}
281-
SpirvType::Vector { element, count } => self.store_vec_or_arr(
282-
original_type,
283-
value,
284-
array,
285-
dynamic_word_index,
286-
constant_word_offset,
287-
element,
288-
count,
289-
),
282+
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
283+
.store_vec_mat_arr(
284+
original_type,
285+
value,
286+
array,
287+
dynamic_word_index,
288+
constant_word_offset,
289+
element,
290+
count,
291+
),
290292
SpirvType::Array { element, count } => {
291293
let count = match self.builder.lookup_const_u64(count) {
292294
Some(count) => count as u32,
293295
None => return self.store_err(original_type, value),
294296
};
295-
self.store_vec_or_arr(
297+
self.store_vec_mat_arr(
296298
original_type,
297299
value,
298300
array,

crates/spirv-std/src/byte_addressable_buffer.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ impl<'a> ByteAddressableBuffer<'a> {
4343
/// This function allows writing a type to an untyped buffer, then reading a different type
4444
/// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a
4545
/// transmute)
46-
pub unsafe fn load<T>(self, byte_index: u32) -> T {
46+
pub unsafe fn load<T>(&self, byte_index: u32) -> T {
4747
if byte_index + mem::size_of::<T>() as u32 > self.data.len() as u32 {
4848
panic!("Index out of range")
4949
}
@@ -58,7 +58,7 @@ impl<'a> ByteAddressableBuffer<'a> {
5858
/// This function allows writing a type to an untyped buffer, then reading a different type
5959
/// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a
6060
/// transmute). Additionally, bounds checking is not performed.
61-
pub unsafe fn load_unchecked<T>(self, byte_index: u32) -> T {
61+
pub unsafe fn load_unchecked<T>(&self, byte_index: u32) -> T {
6262
buffer_load_intrinsic(self.data, byte_index)
6363
}
6464

@@ -69,7 +69,7 @@ impl<'a> ByteAddressableBuffer<'a> {
6969
/// This function allows writing a type to an untyped buffer, then reading a different type
7070
/// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a
7171
/// transmute)
72-
pub unsafe fn store<T>(self, byte_index: u32, value: T) {
72+
pub unsafe fn store<T>(&mut self, byte_index: u32, value: T) {
7373
if byte_index + mem::size_of::<T>() as u32 > self.data.len() as u32 {
7474
panic!("Index out of range")
7575
}
@@ -84,7 +84,7 @@ impl<'a> ByteAddressableBuffer<'a> {
8484
/// This function allows writing a type to an untyped buffer, then reading a different type
8585
/// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a
8686
/// transmute). Additionally, bounds checking is not performed.
87-
pub unsafe fn store_unchecked<T>(self, byte_index: u32, value: T) {
87+
pub unsafe fn store_unchecked<T>(&mut self, byte_index: u32, value: T) {
8888
buffer_store_intrinsic(self.data, byte_index, value);
8989
}
9090
}

tests/ui/byte_addressable_buffer/arr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub fn store(
1919
val: [i32; 4],
2020
) {
2121
unsafe {
22-
let buf = ByteAddressableBuffer::new(buf);
22+
let mut buf = ByteAddressableBuffer::new(buf);
2323
buf.store(5, val);
2424
}
2525
}

tests/ui/byte_addressable_buffer/big_struct.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ pub fn store(
2828
val: BigStruct,
2929
) {
3030
unsafe {
31-
let buf = ByteAddressableBuffer::new(buf);
31+
let mut buf = ByteAddressableBuffer::new(buf);
3232
buf.store(5, val);
3333
}
3434
}

tests/ui/byte_addressable_buffer/complex.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub fn store(
3434
val: Nesty,
3535
) {
3636
unsafe {
37-
let buf = ByteAddressableBuffer::new(buf);
37+
let mut buf = ByteAddressableBuffer::new(buf);
3838
buf.store(5, val);
3939
}
4040
}

tests/ui/byte_addressable_buffer/f32.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub fn load(
1616
#[spirv(fragment)]
1717
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: f32) {
1818
unsafe {
19-
let buf = ByteAddressableBuffer::new(buf);
19+
let mut buf = ByteAddressableBuffer::new(buf);
2020
buf.store(5, val);
2121
}
2222
}

tests/ui/byte_addressable_buffer/u32.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub fn load(
1616
#[spirv(fragment)]
1717
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: u32) {
1818
unsafe {
19-
let buf = ByteAddressableBuffer::new(buf);
19+
let mut buf = ByteAddressableBuffer::new(buf);
2020
buf.store(5, val);
2121
}
2222
}

tests/ui/byte_addressable_buffer/vec.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,36 @@
22

33
use spirv_std::{glam::Vec4, ByteAddressableBuffer};
44

5+
#[spirv(matrix)]
6+
pub struct Mat4 {
7+
x: Vec4,
8+
y: Vec4,
9+
z: Vec4,
10+
w: Vec4,
11+
}
12+
513
#[spirv(fragment)]
614
pub fn load(
715
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
816
out: &mut Vec4,
17+
outmat: &mut Mat4,
918
) {
1019
unsafe {
1120
let buf = ByteAddressableBuffer::new(buf);
1221
*out = buf.load(5);
22+
*outmat = buf.load(5);
1323
}
1424
}
1525

1626
#[spirv(fragment)]
17-
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: Vec4) {
27+
pub fn store(
28+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
29+
val: Vec4,
30+
valmat: Mat4,
31+
) {
1832
unsafe {
19-
let buf = ByteAddressableBuffer::new(buf);
33+
let mut buf = ByteAddressableBuffer::new(buf);
2034
buf.store(5, val);
35+
buf.store(5, valmat);
2136
}
2237
}

0 commit comments

Comments
 (0)