Skip to content

Commit f0ff173

Browse files
authored
[naga] Implement constant evaluation for the cross builtin. (#7404)
Add support for `naga::ir::MathFunction::Cross` to `naga::proc::constant_evaluator`. In the tests: - Change `naga/tests/in/wgsl/cross.wgsl` to use more interesting argument values. Rather than passing the same vector twice, which yields a cross product of zero, pass in the x and y unit vectors, whose cross product is the z unit vector. Update snapshot output. - Replace `validation::bad_cross_builtin_args` with a new test, `builtin_cross_product_args`, that is written more in the style of the other tests in this module, and does not depend on the WGSL front end. Because this PR changes the behavior of the constant evaluator, this test stopped behaving as expected. - In `wgsl_errors::check`, move a `panic!` out of a closure so that the `#[track_caller]` attribute works properly.
1 parent 10cd1cc commit f0ff173

File tree

9 files changed

+204
-60
lines changed

9 files changed

+204
-60
lines changed

naga/src/proc/constant_evaluator.rs

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,6 @@ impl ExpressionKindTracker {
474474
fun,
475475
Mf::Dot
476476
| Mf::Outer
477-
| Mf::Cross
478477
| Mf::Distance
479478
| Mf::Length
480479
| Mf::Normalize
@@ -1347,12 +1346,116 @@ impl<'a> ConstantEvaluator<'a> {
13471346
component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
13481347
}
13491348

1349+
// vector
1350+
crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1351+
13501352
fun => Err(ConstantEvaluatorError::NotImplemented(format!(
13511353
"{fun:?} built-in function"
13521354
))),
13531355
}
13541356
}
13551357

1358+
/// Vector cross product.
1359+
fn cross_product(
1360+
&mut self,
1361+
a: Handle<Expression>,
1362+
b: Handle<Expression>,
1363+
span: Span,
1364+
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1365+
use Literal as Li;
1366+
1367+
let (a, ty) = self.extract_vec::<3>(a)?;
1368+
let (b, _) = self.extract_vec::<3>(b)?;
1369+
1370+
let product = match (a, b) {
1371+
(
1372+
[Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1373+
[Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1374+
) => {
1375+
// `cross` has no overload for AbstractInt, so AbstractInt
1376+
// arguments are automatically converted to AbstractFloat. Since
1377+
// `f64` has a much wider range than `i64`, there's no danger of
1378+
// overflow here.
1379+
let p = cross_product(
1380+
[a0 as f64, a1 as f64, a2 as f64],
1381+
[b0 as f64, b1 as f64, b2 as f64],
1382+
);
1383+
[
1384+
Li::AbstractFloat(p[0]),
1385+
Li::AbstractFloat(p[1]),
1386+
Li::AbstractFloat(p[2]),
1387+
]
1388+
}
1389+
(
1390+
[Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1391+
[Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1392+
) => {
1393+
let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1394+
[
1395+
Li::AbstractFloat(p[0]),
1396+
Li::AbstractFloat(p[1]),
1397+
Li::AbstractFloat(p[2]),
1398+
]
1399+
}
1400+
([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1401+
let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1402+
[Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1403+
}
1404+
([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1405+
let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1406+
[Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1407+
}
1408+
([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1409+
let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1410+
[Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1411+
}
1412+
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
1413+
};
1414+
1415+
let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1416+
let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1417+
let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1418+
1419+
self.register_evaluated_expr(
1420+
Expression::Compose {
1421+
ty,
1422+
components: vec![p0, p1, p2],
1423+
},
1424+
span,
1425+
)
1426+
}
1427+
1428+
/// Extract the values of a `vecN` from `expr`.
1429+
///
1430+
/// Return the value of `expr`, whose type is `vecN<S>` for some
1431+
/// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
1432+
/// values.
1433+
///
1434+
/// Also return the type handle from the `Compose` expression.
1435+
fn extract_vec<const N: usize>(
1436+
&mut self,
1437+
expr: Handle<Expression>,
1438+
) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
1439+
let span = self.expressions.get_span(expr);
1440+
let expr = self.eval_zero_value_and_splat(expr, span)?;
1441+
let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1442+
return Err(ConstantEvaluatorError::InvalidMathArg);
1443+
};
1444+
1445+
let mut value = [Literal::Bool(false); N];
1446+
for (component, elt) in
1447+
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1448+
.zip(value.iter_mut())
1449+
{
1450+
let Expression::Literal(literal) = self.expressions[component] else {
1451+
return Err(ConstantEvaluatorError::InvalidMathArg);
1452+
};
1453+
*elt = literal;
1454+
}
1455+
1456+
Ok((value, ty))
1457+
}
1458+
13561459
fn array_length(
13571460
&mut self,
13581461
array: Handle<Expression>,
@@ -2689,6 +2792,19 @@ impl TryFromAbstract<i64> for f16 {
26892792
}
26902793
}
26912794

2795+
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
2796+
where
2797+
T: Copy,
2798+
T: core::ops::Mul<T, Output = T>,
2799+
T: core::ops::Sub<T, Output = T>,
2800+
{
2801+
[
2802+
a[1] * b[2] - a[2] * b[1],
2803+
a[2] * b[0] - a[0] * b[2],
2804+
a[0] * b[1] - a[1] * b[0],
2805+
]
2806+
}
2807+
26922808
#[cfg(test)]
26932809
mod tests {
26942810
use alloc::{vec, vec::Vec};

naga/tests/in/wgsl/cross.wgsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
// NOTE: invalid combinations are tested in the `validation::bad_cross_builtin_args` test.
22
@compute @workgroup_size(1) fn main() {
3-
let a = cross(vec3(0., 1., 2.), vec3(0., 1., 2.));
3+
let a = cross(vec3(1., 0., 0.), vec3(0., 1., 0.));
44
}

naga/tests/out/glsl/cross.main.Compute.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
77

88

99
void main() {
10-
vec3 a = cross(vec3(0.0, 1.0, 2.0), vec3(0.0, 1.0, 2.0));
10+
vec3 a = vec3(0.0, 0.0, 1.0);
1111
return;
1212
}
1313

naga/tests/out/hlsl/cross.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[numthreads(1, 1, 1)]
22
void main()
33
{
4-
float3 a = cross(float3(0.0, 1.0, 2.0), float3(0.0, 1.0, 2.0));
4+
float3 a = float3(0.0, 0.0, 1.0);
55
return;
66
}

naga/tests/out/msl/cross.msl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ using metal::uint;
77

88
kernel void main_(
99
) {
10-
metal::float3 a = metal::cross(metal::float3(0.0, 1.0, 2.0), metal::float3(0.0, 1.0, 2.0));
10+
metal::float3 a = metal::float3(0.0, 0.0, 1.0);
1111
return;
1212
}

naga/tests/out/spv/cross.spvasm

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; SPIR-V
22
; Version: 1.1
33
; Generator: rspirv
4-
; Bound: 14
4+
; Bound: 12
55
OpCapability Shader
66
%1 = OpExtInstImport "GLSL.std.450"
77
OpMemoryModel Logical GLSL450
@@ -13,12 +13,10 @@ OpExecutionMode %6 LocalSize 1 1 1
1313
%7 = OpTypeFunction %2
1414
%8 = OpConstant %4 0.0
1515
%9 = OpConstant %4 1.0
16-
%10 = OpConstant %4 2.0
17-
%11 = OpConstantComposite %3 %8 %9 %10
16+
%10 = OpConstantComposite %3 %8 %8 %9
1817
%6 = OpFunction %2 None %7
1918
%5 = OpLabel
20-
OpBranch %12
21-
%12 = OpLabel
22-
%13 = OpExtInst %3 %1 Cross %11 %11
19+
OpBranch %11
20+
%11 = OpLabel
2321
OpReturn
2422
OpFunctionEnd

naga/tests/out/wgsl/cross.wgsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@compute @workgroup_size(1, 1, 1)
22
fn main() {
3-
let a = cross(vec3<f32>(0f, 1f, 2f), vec3<f32>(0f, 1f, 2f));
3+
const a = vec3<f32>(0f, 0f, 1f);
44
return;
55
}

naga/tests/validation.rs

Lines changed: 74 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -261,56 +261,86 @@ fn emit_workgroup_uniform_load_result() {
261261
assert!(variant(false).is_err());
262262
}
263263

264-
#[cfg(feature = "wgsl-in")]
265264
#[test]
266-
fn bad_cross_builtin_args() {
267-
// NOTE: Things we expect to actually compile are in the `cross` snapshot test.
268-
let cases = [
269-
(
270-
"vec2(0., 1.)",
271-
"\
272-
error: Entry point main at Compute is invalid
273-
┌─ wgsl:3:13
274-
275-
3 │ let a = cross(vec2(0., 1.), vec2(0., 1.));
276-
│ ^^^^^ naga::ir::Expression [6]
277-
278-
= Expression [6] is invalid
279-
= Argument [0] to Cross as expression [2] has an invalid type.
265+
fn builtin_cross_product_args() {
266+
use naga::{MathFunction, Module, Type, TypeInner, VectorSize};
280267

281-
",
282-
),
283-
(
284-
"vec4(0., 1., 2., 3.)",
285-
"\
286-
error: Entry point main at Compute is invalid
287-
┌─ wgsl:3:13
288-
289-
3 │ let a = cross(vec4(0., 1., 2., 3.), vec4(0., 1., 2., 3.));
290-
│ ^^^^^ naga::ir::Expression [10]
291-
292-
= Expression [10] is invalid
293-
= Argument [0] to Cross as expression [4] has an invalid type.
268+
// We want to ensure that the *only* problem with the code is the
269+
// arity of the vectors passed to `cross`. So validate two
270+
// versions of the module varying only in that aspect.
271+
//
272+
// Looking at uses of the `wg_load` makes it easy to identify the
273+
// differences between the two variants.
274+
fn variant(
275+
size: VectorSize,
276+
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
277+
let span = naga::Span::default();
278+
let mut module = Module::default();
279+
let ty_vec3f = module.types.insert(
280+
Type {
281+
name: Some("vecnf".into()),
282+
inner: TypeInner::Vector {
283+
size: VectorSize::Tri,
284+
scalar: Scalar::F32,
285+
},
286+
},
287+
span,
288+
);
289+
let ty_vecnf = module.types.insert(
290+
Type {
291+
name: Some("vecnf".into()),
292+
inner: TypeInner::Vector {
293+
size,
294+
scalar: Scalar::F32,
295+
},
296+
},
297+
span,
298+
);
294299

295-
",
296-
),
297-
];
300+
let mut fun = Function {
301+
result: Some(naga::ir::FunctionResult {
302+
ty: ty_vec3f,
303+
binding: None,
304+
}),
305+
..Function::default()
306+
};
307+
let ex_zero = fun
308+
.expressions
309+
.append(Expression::ZeroValue(ty_vecnf), span);
310+
let ex_cross = fun.expressions.append(
311+
Expression::Math {
312+
fun: MathFunction::Cross,
313+
arg: ex_zero,
314+
arg1: Some(ex_zero),
315+
arg2: None,
316+
arg3: None,
317+
},
318+
span,
319+
);
298320

299-
for (invalid_arg, expected_err) in cases {
300-
let source = format!(
301-
"\
302-
@compute @workgroup_size(1)
303-
fn main() {{
304-
let a = cross({invalid_arg}, {invalid_arg});
305-
}}
306-
"
321+
fun.body.push(
322+
naga::Statement::Emit(naga::Range::new_from_bounds(ex_cross, ex_cross)),
323+
span,
307324
);
308-
let module = naga::front::wgsl::parse_str(&source).unwrap();
309-
let err = valid::Validator::new(Default::default(), valid::Capabilities::all())
310-
.validate(&module)
311-
.expect_err("module should be invalid");
312-
assert_eq!(err.emit_to_string(&source), expected_err);
325+
fun.body.push(
326+
naga::Statement::Return {
327+
value: Some(ex_cross),
328+
},
329+
span,
330+
);
331+
332+
module.functions.append(fun, span);
333+
334+
valid::Validator::new(
335+
valid::ValidationFlags::default(),
336+
valid::Capabilities::all(),
337+
)
338+
.validate(&module)
313339
}
340+
341+
assert!(variant(VectorSize::Bi).is_err());
342+
variant(VectorSize::Tri).expect("module should validate");
343+
assert!(variant(VectorSize::Quad).is_err());
314344
}
315345

316346
#[cfg(feature = "wgsl-in")]

naga/tests/wgsl_errors.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ use naga::valid::Capabilities;
77

88
#[track_caller]
99
fn check(input: &str, snapshot: &str) {
10-
let output = naga::front::wgsl::parse_str(input)
11-
.map(|_| panic!("expected parser error, but parsing succeeded!"))
12-
.unwrap_err()
13-
.emit_to_string(input);
10+
let output = match naga::front::wgsl::parse_str(input) {
11+
Ok(_) => panic!("expected parser error, but parsing succeeded!"),
12+
Err(err) => err.emit_to_string(input),
13+
};
1414
if output != snapshot {
1515
for diff in diff::lines(snapshot, &output) {
1616
match diff {

0 commit comments

Comments
 (0)