Skip to content

Commit e81b34d

Browse files
committed
Road floats to prevent deltas
1 parent 8b6c3b8 commit e81b34d

File tree

11 files changed

+160
-134
lines changed

11 files changed

+160
-134
lines changed

tests/difftests/lib/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use-compiled-tools = [
1515
"spirv-builder/use-compiled-tools"
1616
]
1717

18-
[dependencies]
18+
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
1919
spirv-builder.workspace = true
2020
serde = { version = "1.0", features = ["derive"] }
2121
serde_json = "1.0"

tests/difftests/lib/src/lib.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
1+
#![cfg_attr(target_arch = "spirv", no_std)]
2+
3+
#[cfg(not(target_arch = "spirv"))]
14
pub mod config;
5+
#[cfg(not(target_arch = "spirv"))]
26
pub mod scaffold;
37

8+
/// Macro to round a f32 value to 6 decimal places for cross-platform consistency
9+
/// in floating-point operations. This helps ensure difftest results are consistent
10+
/// across different platforms (Linux, Mac, Windows) which may have slight differences
11+
/// in floating-point implementations.
12+
#[macro_export]
13+
macro_rules! round6 {
14+
($v:expr) => {
15+
(($v) * 1_000_000.0).round() / 1_000_000.0
16+
};
17+
}
18+
419
#[cfg(test)]
520
mod tests {
621
use super::config::Config;

tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/Cargo.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@ crate-type = ["dylib"]
1010

1111
# Common deps
1212
[dependencies]
13-
14-
# GPU deps
1513
spirv-std.workspace = true
14+
difftest.workspace = true
1615

1716
# CPU deps
1817
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
19-
difftest.workspace = true
20-
bytemuck.workspace = true
18+
bytemuck.workspace = true

tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![no_std]
22

3+
use difftest::round6;
34
#[allow(unused_imports)]
45
use spirv_std::num_traits::Float;
56
use spirv_std::spirv;
@@ -24,28 +25,28 @@ pub fn main_cs(
2425
}
2526

2627
// Basic arithmetic
27-
output[base_offset + 0] = x + 1.5;
28-
output[base_offset + 1] = x - 0.5;
29-
output[base_offset + 2] = x * 2.0;
30-
output[base_offset + 3] = x / 2.0;
31-
output[base_offset + 4] = x % 3.0;
28+
output[base_offset + 0] = round6!(x + 1.5);
29+
output[base_offset + 1] = round6!(x - 0.5);
30+
output[base_offset + 2] = round6!(x * 2.0);
31+
output[base_offset + 3] = round6!(x / 2.0);
32+
output[base_offset + 4] = round6!(x % 3.0);
3233

3334
// Trigonometric functions (simplified for consistent results)
34-
output[base_offset + 5] = x.sin();
35-
output[base_offset + 6] = x.cos();
36-
output[base_offset + 7] = x.tan().clamp(-10.0, 10.0);
35+
output[base_offset + 5] = round6!(x.sin());
36+
output[base_offset + 6] = round6!(x.cos());
37+
output[base_offset + 7] = round6!(x.tan().clamp(-10.0, 10.0));
3738
output[base_offset + 8] = 0.0;
3839
output[base_offset + 9] = 0.0;
39-
output[base_offset + 10] = x.atan();
40+
output[base_offset + 10] = round6!(x.atan());
4041

4142
// Exponential and logarithmic (simplified)
42-
output[base_offset + 11] = x.exp().min(1e6);
43-
output[base_offset + 12] = if x > 0.0 { x.ln() } else { -10.0 };
44-
output[base_offset + 13] = x.abs().sqrt();
45-
output[base_offset + 14] = x.abs().powf(2.0);
46-
output[base_offset + 15] = if x > 0.0 { x.log2() } else { -10.0 };
47-
output[base_offset + 16] = x.exp2().min(1e6);
48-
output[base_offset + 17] = x.floor();
43+
output[base_offset + 11] = round6!(x.exp().min(1e6));
44+
output[base_offset + 12] = round6!(if x > 0.0 { x.ln() } else { -10.0 });
45+
output[base_offset + 13] = round6!(x.abs().sqrt());
46+
output[base_offset + 14] = round6!(x.abs() * x.abs()); // Use multiplication instead of powf
47+
output[base_offset + 15] = round6!(if x > 0.0 { x.log2() } else { -10.0 });
48+
output[base_offset + 16] = round6!(x.exp2().min(1e6));
49+
output[base_offset + 17] = x.floor(); // floor/ceil/round are exact
4950
output[base_offset + 18] = x.ceil();
5051
output[base_offset + 19] = x.round();
5152

tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ var<storage, read> input: array<f32>;
44
@group(0) @binding(1)
55
var<storage, read_write> output: array<f32>;
66

7+
// Helper function to round to 6 decimal places for cross-platform consistency
8+
fn round6(v: f32) -> f32 {
9+
return round(v * 1000000.0) / 1000000.0;
10+
}
11+
712
@compute @workgroup_size(32, 1, 1)
813
fn main_cs(@builtin(global_invocation_id) global_id: vec3<u32>) {
914
let tid = global_id.x;
@@ -20,28 +25,28 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3<u32>) {
2025
}
2126

2227
// Basic arithmetic
23-
output[base_offset + 0u] = x + 1.5;
24-
output[base_offset + 1u] = x - 0.5;
25-
output[base_offset + 2u] = x * 2.0;
26-
output[base_offset + 3u] = x / 2.0;
27-
output[base_offset + 4u] = x % 3.0;
28+
output[base_offset + 0u] = round6(x + 1.5);
29+
output[base_offset + 1u] = round6(x - 0.5);
30+
output[base_offset + 2u] = round6(x * 2.0);
31+
output[base_offset + 3u] = round6(x / 2.0);
32+
output[base_offset + 4u] = round6(x % 3.0);
2833

2934
// Trigonometric functions (simplified for consistent results)
30-
output[base_offset + 5u] = sin(x);
31-
output[base_offset + 6u] = cos(x);
32-
output[base_offset + 7u] = clamp(tan(x), -10.0, 10.0);
35+
output[base_offset + 5u] = round6(sin(x));
36+
output[base_offset + 6u] = round6(cos(x));
37+
output[base_offset + 7u] = round6(clamp(tan(x), -10.0, 10.0));
3338
output[base_offset + 8u] = 0.0;
3439
output[base_offset + 9u] = 0.0;
35-
output[base_offset + 10u] = atan(x);
40+
output[base_offset + 10u] = round6(atan(x));
3641

3742
// Exponential and logarithmic (simplified)
38-
output[base_offset + 11u] = min(exp(x), 1e6);
39-
output[base_offset + 12u] = select(-10.0, log(x), x > 0.0);
40-
output[base_offset + 13u] = sqrt(abs(x));
41-
output[base_offset + 14u] = pow(abs(x), 2.0);
42-
output[base_offset + 15u] = select(-10.0, log2(x), x > 0.0);
43-
output[base_offset + 16u] = min(exp2(x), 1e6);
44-
output[base_offset + 17u] = floor(x);
43+
output[base_offset + 11u] = round6(min(exp(x), 1e6));
44+
output[base_offset + 12u] = round6(select(-10.0, log(x), x > 0.0));
45+
output[base_offset + 13u] = round6(sqrt(abs(x)));
46+
output[base_offset + 14u] = round6(abs(x) * abs(x)); // Use multiplication instead of pow
47+
output[base_offset + 15u] = round6(select(-10.0, log2(x), x > 0.0));
48+
output[base_offset + 16u] = round6(min(exp2(x), 1e6));
49+
output[base_offset + 17u] = floor(x); // floor/ceil/round are exact
4550
output[base_offset + 18u] = ceil(x);
4651
output[base_offset + 19u] = round(x);
4752

tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/Cargo.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@ crate-type = ["dylib"]
1010

1111
# Common deps
1212
[dependencies]
13-
14-
# GPU deps
1513
spirv-std.workspace = true
14+
difftest.workspace = true
1615

1716
# CPU deps
1817
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
19-
difftest.workspace = true
20-
bytemuck.workspace = true
18+
bytemuck.workspace = true

tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![no_std]
22

3+
use difftest::round6;
34
use spirv_std::glam::{Mat2, Mat3, Mat4, UVec3, Vec2, Vec3, Vec4};
45
#[allow(unused_imports)]
56
use spirv_std::num_traits::Float;
@@ -35,10 +36,10 @@ pub fn main_cs(
3536

3637
// Mat2 multiplication
3738
let m2_mul = m2a * m2b;
38-
output[base_offset + 0] = m2_mul.col(0).x;
39-
output[base_offset + 1] = m2_mul.col(0).y;
40-
output[base_offset + 2] = m2_mul.col(1).x;
41-
output[base_offset + 3] = m2_mul.col(1).y;
39+
output[base_offset + 0] = round6!(m2_mul.col(0).x);
40+
output[base_offset + 1] = round6!(m2_mul.col(0).y);
41+
output[base_offset + 2] = round6!(m2_mul.col(1).x);
42+
output[base_offset + 3] = round6!(m2_mul.col(1).y);
4243

4344
// Mat2 transpose
4445
let m2_transpose = m2a.transpose();
@@ -48,29 +49,29 @@ pub fn main_cs(
4849
output[base_offset + 7] = m2_transpose.col(1).y;
4950

5051
// Mat2 determinant (with rounding for consistency)
51-
output[base_offset + 8] = (m2a.determinant() * 1000.0).round() / 1000.0;
52+
output[base_offset + 8] = round6!(m2a.determinant());
5253

5354
// Mat2 * Vec2
5455
let v2 = Vec2::new(1.0, 2.0);
5556
let m2_v2 = m2a * v2;
56-
output[base_offset + 9] = m2_v2.x;
57-
output[base_offset + 10] = m2_v2.y;
57+
output[base_offset + 9] = round6!(m2_v2.x);
58+
output[base_offset + 10] = round6!(m2_v2.y);
5859

5960
// Mat3 operations
6061
let m3a = Mat3::from_cols(Vec3::new(a, b, c), Vec3::new(b, c, d), Vec3::new(c, d, a));
6162
let m3b = Mat3::from_cols(Vec3::new(d, c, b), Vec3::new(c, b, a), Vec3::new(b, a, d));
6263

6364
// Mat3 multiplication
6465
let m3_mul = m3a * m3b;
65-
output[base_offset + 11] = m3_mul.col(0).x;
66-
output[base_offset + 12] = m3_mul.col(0).y;
67-
output[base_offset + 13] = m3_mul.col(0).z;
68-
output[base_offset + 14] = m3_mul.col(1).x;
69-
output[base_offset + 15] = m3_mul.col(1).y;
70-
output[base_offset + 16] = m3_mul.col(1).z;
71-
output[base_offset + 17] = m3_mul.col(2).x;
72-
output[base_offset + 18] = m3_mul.col(2).y;
73-
output[base_offset + 19] = m3_mul.col(2).z;
66+
output[base_offset + 11] = round6!(m3_mul.col(0).x);
67+
output[base_offset + 12] = round6!(m3_mul.col(0).y);
68+
output[base_offset + 13] = round6!(m3_mul.col(0).z);
69+
output[base_offset + 14] = round6!(m3_mul.col(1).x);
70+
output[base_offset + 15] = round6!(m3_mul.col(1).y);
71+
output[base_offset + 16] = round6!(m3_mul.col(1).z);
72+
output[base_offset + 17] = round6!(m3_mul.col(2).x);
73+
output[base_offset + 18] = round6!(m3_mul.col(2).y);
74+
output[base_offset + 19] = round6!(m3_mul.col(2).z);
7475

7576
// Mat3 transpose - store just diagonal elements
7677
let m3_transpose = m3a.transpose();
@@ -79,14 +80,14 @@ pub fn main_cs(
7980
output[base_offset + 22] = m3_transpose.col(2).z;
8081

8182
// Mat3 determinant (with rounding for consistency)
82-
output[base_offset + 23] = (m3a.determinant() * 1000.0).round() / 1000.0;
83+
output[base_offset + 23] = round6!(m3a.determinant());
8384

8485
// Mat3 * Vec3 (with rounding for consistency)
8586
let v3 = Vec3::new(1.0, 2.0, 3.0);
8687
let m3_v3 = m3a * v3;
87-
output[base_offset + 24] = (m3_v3.x * 10000.0).round() / 10000.0;
88-
output[base_offset + 25] = (m3_v3.y * 10000.0).round() / 10000.0;
89-
output[base_offset + 26] = (m3_v3.z * 10000.0).round() / 10000.0;
88+
output[base_offset + 24] = round6!(m3_v3.x);
89+
output[base_offset + 25] = round6!(m3_v3.y);
90+
output[base_offset + 26] = round6!(m3_v3.z);
9091

9192
// Mat4 operations
9293
let m4a = Mat4::from_cols(
@@ -104,10 +105,10 @@ pub fn main_cs(
104105

105106
// Mat4 multiplication (just store diagonal for brevity)
106107
let m4_mul = m4a * m4b;
107-
output[base_offset + 27] = m4_mul.col(0).x;
108-
output[base_offset + 28] = m4_mul.col(1).y;
109-
output[base_offset + 29] = m4_mul.col(2).z;
110-
output[base_offset + 30] = m4_mul.col(3).w;
108+
output[base_offset + 27] = round6!(m4_mul.col(0).x);
109+
output[base_offset + 28] = round6!(m4_mul.col(1).y);
110+
output[base_offset + 29] = round6!(m4_mul.col(2).z);
111+
output[base_offset + 30] = round6!(m4_mul.col(3).w);
111112

112113
// Mat4 transpose (just store diagonal)
113114
let m4_transpose = m4a.transpose();
@@ -117,15 +118,15 @@ pub fn main_cs(
117118
output[base_offset + 34] = m4_transpose.col(3).w;
118119

119120
// Mat4 determinant (with rounding for consistency)
120-
output[base_offset + 35] = (m4a.determinant() * 1000.0).round() / 1000.0;
121+
output[base_offset + 35] = round6!(m4a.determinant());
121122

122123
// Mat4 * Vec4 (with rounding for consistency)
123124
let v4 = Vec4::new(1.0, 2.0, 3.0, 4.0);
124125
let m4_v4 = m4a * v4;
125-
output[base_offset + 36] = (m4_v4.x * 10000.0).round() / 10000.0;
126-
output[base_offset + 37] = (m4_v4.y * 10000.0).round() / 10000.0;
127-
output[base_offset + 38] = (m4_v4.z * 10000.0).round() / 10000.0;
128-
output[base_offset + 39] = (m4_v4.w * 10000.0).round() / 10000.0;
126+
output[base_offset + 36] = round6!(m4_v4.x);
127+
output[base_offset + 37] = round6!(m4_v4.y);
128+
output[base_offset + 38] = round6!(m4_v4.z);
129+
output[base_offset + 39] = round6!(m4_v4.w);
129130

130131
// Identity matrices
131132
output[base_offset + 40] = Mat2::IDENTITY.col(0).x;
@@ -135,8 +136,8 @@ pub fn main_cs(
135136
// Matrix inverse
136137
if m2a.determinant().abs() > 0.0001 {
137138
let m2_inv = m2a.inverse();
138-
output[base_offset + 43] = m2_inv.col(0).x;
139-
output[base_offset + 44] = m2_inv.col(1).y;
139+
output[base_offset + 43] = round6!(m2_inv.col(0).x);
140+
output[base_offset + 44] = round6!(m2_inv.col(1).y);
140141
} else {
141142
output[base_offset + 43] = 0.0;
142143
output[base_offset + 44] = 0.0;

0 commit comments

Comments
 (0)