Skip to content

Commit 5dd0200

Browse files
committed
subgroup: add trait VectorOrScalar, representing either a vector or a scalar type
1 parent 4c9718f commit 5dd0200

File tree

5 files changed

+106
-5
lines changed

5 files changed

+106
-5
lines changed

crates/spirv-std/src/float.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Traits and helper functions related to floats.
22
3+
use crate::scalar::VectorOrScalar;
34
use crate::vector::Vector;
45
#[cfg(target_arch = "spirv")]
56
use core::arch::asm;
@@ -71,6 +72,9 @@ struct F32x2 {
7172
x: f32,
7273
y: f32,
7374
}
75+
unsafe impl VectorOrScalar for F32x2 {
76+
type Scalar = f32;
77+
}
7478
unsafe impl Vector<f32, 2> for F32x2 {}
7579

7680
/// Converts an f32 (float) into an f16 (half). The result is a u32, not a u16, due to GPU support

crates/spirv-std/src/scalar.rs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,58 @@
11
//! Traits related to scalars.
22
3+
/// Abstract trait representing either a vector or a scalar type.
4+
///
5+
/// # Safety
6+
/// Implementing this trait on non-scalar or non-vector types may break assumptions about other
7+
/// unsafe code, and should not be done.
8+
pub unsafe trait VectorOrScalar: Default {
9+
/// Either the scalar component type of the vector or the scalar itself.
10+
type Scalar: Scalar;
11+
}
12+
13+
unsafe impl VectorOrScalar for bool {
14+
type Scalar = bool;
15+
}
16+
unsafe impl VectorOrScalar for f32 {
17+
type Scalar = f32;
18+
}
19+
unsafe impl VectorOrScalar for f64 {
20+
type Scalar = f64;
21+
}
22+
unsafe impl VectorOrScalar for u8 {
23+
type Scalar = u8;
24+
}
25+
unsafe impl VectorOrScalar for u16 {
26+
type Scalar = u16;
27+
}
28+
unsafe impl VectorOrScalar for u32 {
29+
type Scalar = u32;
30+
}
31+
unsafe impl VectorOrScalar for u64 {
32+
type Scalar = u64;
33+
}
34+
unsafe impl VectorOrScalar for i8 {
35+
type Scalar = i8;
36+
}
37+
unsafe impl VectorOrScalar for i16 {
38+
type Scalar = i16;
39+
}
40+
unsafe impl VectorOrScalar for i32 {
41+
type Scalar = i32;
42+
}
43+
unsafe impl VectorOrScalar for i64 {
44+
type Scalar = i64;
45+
}
46+
347
/// Abstract trait representing a SPIR-V scalar type.
448
///
549
/// # Safety
650
/// Implementing this trait on non-scalar types breaks assumptions of other unsafe code, and should
751
/// not be done.
8-
pub unsafe trait Scalar: Copy + Default + crate::sealed::Sealed {}
52+
pub unsafe trait Scalar:
53+
VectorOrScalar<Scalar = Self> + Copy + Default + crate::sealed::Sealed
54+
{
55+
}
956

1057
unsafe impl Scalar for bool {}
1158
unsafe impl Scalar for f32 {}

crates/spirv-std/src/vector.rs

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,57 @@
11
//! Traits related to vectors.
22
3+
use crate::scalar::{Scalar, VectorOrScalar};
34
use glam::{Vec3Swizzles, Vec4Swizzles};
45

6+
unsafe impl VectorOrScalar for glam::Vec2 {
7+
type Scalar = f32;
8+
}
9+
unsafe impl VectorOrScalar for glam::Vec3 {
10+
type Scalar = f32;
11+
}
12+
unsafe impl VectorOrScalar for glam::Vec3A {
13+
type Scalar = f32;
14+
}
15+
unsafe impl VectorOrScalar for glam::Vec4 {
16+
type Scalar = f32;
17+
}
18+
19+
unsafe impl VectorOrScalar for glam::DVec2 {
20+
type Scalar = f64;
21+
}
22+
unsafe impl VectorOrScalar for glam::DVec3 {
23+
type Scalar = f64;
24+
}
25+
unsafe impl VectorOrScalar for glam::DVec4 {
26+
type Scalar = f64;
27+
}
28+
29+
unsafe impl VectorOrScalar for glam::UVec2 {
30+
type Scalar = u32;
31+
}
32+
unsafe impl VectorOrScalar for glam::UVec3 {
33+
type Scalar = u32;
34+
}
35+
unsafe impl VectorOrScalar for glam::UVec4 {
36+
type Scalar = u32;
37+
}
38+
39+
unsafe impl VectorOrScalar for glam::IVec2 {
40+
type Scalar = i32;
41+
}
42+
unsafe impl VectorOrScalar for glam::IVec3 {
43+
type Scalar = i32;
44+
}
45+
unsafe impl VectorOrScalar for glam::IVec4 {
46+
type Scalar = i32;
47+
}
48+
549
/// Abstract trait representing a SPIR-V vector type.
650
///
751
/// # Safety
852
/// Implementing this trait on non-simd-vector types breaks assumptions of other unsafe code, and
953
/// should not be done.
10-
pub unsafe trait Vector<T: crate::scalar::Scalar, const N: usize>: Default {}
54+
pub unsafe trait Vector<T: Scalar, const N: usize>: VectorOrScalar<Scalar = T> {}
1155

1256
unsafe impl Vector<f32, 2> for glam::Vec2 {}
1357
unsafe impl Vector<f32, 3> for glam::Vec3 {}
@@ -27,7 +71,7 @@ unsafe impl Vector<i32, 3> for glam::IVec3 {}
2771
unsafe impl Vector<i32, 4> for glam::IVec4 {}
2872

2973
/// Trait that implements slicing of a vector into a scalar or vector of lower dimensions, by
30-
/// ignoring the highter dimensions
74+
/// ignoring the higter dimensions
3175
pub trait VectorTruncateInto<T> {
3276
/// Slices the vector into a lower dimensional type by ignoring the higher components
3377
fn truncate_into(self) -> T;

tests/ui/arch/all.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#![feature(repr_simd)]
44

55
use spirv_std::spirv;
6-
use spirv_std::{scalar::Scalar, vector::Vector};
6+
use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};
77

88
/// HACK(shesp). Rust doesn't allow us to declare regular (tuple-)structs containing `bool` members
99
/// as `#[repl(simd)]`. But we need this for `spirv_std::arch::any()` and `spirv_std::arch::all()`
@@ -12,6 +12,9 @@ use spirv_std::{scalar::Scalar, vector::Vector};
1212
/// it (for now at least)
1313
#[repr(simd)]
1414
struct Vec2<T>(T, T);
15+
unsafe impl<T: Scalar> VectorOrScalar for Vec2<T> {
16+
type Scalar = T;
17+
}
1518
unsafe impl<T: Scalar> Vector<T, 2> for Vec2<T> {}
1619

1720
impl<T: Scalar> Default for Vec2<T> {

tests/ui/arch/any.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#![feature(repr_simd)]
44

55
use spirv_std::spirv;
6-
use spirv_std::{scalar::Scalar, vector::Vector};
6+
use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};
77

88
/// HACK(shesp). Rust doesn't allow us to declare regular (tuple-)structs containing `bool` members
99
/// as `#[repl(simd)]`. But we need this for `spirv_std::arch::any()` and `spirv_std::arch::all()`
@@ -12,6 +12,9 @@ use spirv_std::{scalar::Scalar, vector::Vector};
1212
/// it (for now at least)
1313
#[repr(simd)]
1414
struct Vec2<T>(T, T);
15+
unsafe impl<T: Scalar> VectorOrScalar for Vec2<T> {
16+
type Scalar = T;
17+
}
1518
unsafe impl<T: Scalar> Vector<T, 2> for Vec2<T> {}
1619

1720
impl<T: Scalar> Default for Vec2<T> {

0 commit comments

Comments
 (0)