Skip to content

Commit 4e9a2a5

Browse files
authored
[naga wgsl] Impl const_assert (#6198)
Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com>
1 parent ace2e20 commit 4e9a2a5

File tree

12 files changed

+299
-17
lines changed

12 files changed

+299
-17
lines changed

naga/src/front/wgsl/error.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ pub(crate) enum Error<'a> {
263263
limit: u8,
264264
},
265265
PipelineConstantIDValue(Span),
266+
NotBool(Span),
267+
ConstAssertFailed(Span),
266268
}
267269

268270
#[derive(Clone, Debug)]
@@ -815,6 +817,22 @@ impl<'a> Error<'a> {
815817
)],
816818
notes: vec![],
817819
},
820+
Error::NotBool(span) => ParseError {
821+
message: "must be a const-expression that resolves to a bool".to_string(),
822+
labels: vec![(
823+
span,
824+
"must resolve to bool".into(),
825+
)],
826+
notes: vec![],
827+
},
828+
Error::ConstAssertFailed(span) => ParseError {
829+
message: "const_assert failure".to_string(),
830+
labels: vec![(
831+
span,
832+
"evaluates to false".into(),
833+
)],
834+
notes: vec![],
835+
},
818836
}
819837
}
820838
}

naga/src/front/wgsl/index.rs

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@ impl<'a> Index<'a> {
2020
// While doing so, reject conflicting definitions.
2121
let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default());
2222
for (handle, decl) in tu.decls.iter() {
23-
let ident = decl_ident(decl);
24-
let name = ident.name;
25-
if let Some(old) = globals.insert(name, handle) {
26-
return Err(Error::Redefinition {
27-
previous: decl_ident(&tu.decls[old]).span,
28-
current: ident.span,
29-
});
23+
if let Some(ident) = decl_ident(decl) {
24+
let name = ident.name;
25+
if let Some(old) = globals.insert(name, handle) {
26+
return Err(Error::Redefinition {
27+
previous: decl_ident(&tu.decls[old])
28+
.expect("decl should have ident for redefinition")
29+
.span,
30+
current: ident.span,
31+
});
32+
}
3033
}
3134
}
3235

@@ -130,7 +133,7 @@ impl<'a> DependencySolver<'a, '_> {
130133
return if dep_id == id {
131134
// A declaration refers to itself directly.
132135
Err(Error::RecursiveDeclaration {
133-
ident: decl_ident(decl).span,
136+
ident: decl_ident(decl).expect("decl should have ident").span,
134137
usage: dep.usage,
135138
})
136139
} else {
@@ -146,14 +149,19 @@ impl<'a> DependencySolver<'a, '_> {
146149
.unwrap_or(0);
147150

148151
Err(Error::CyclicDeclaration {
149-
ident: decl_ident(&self.module.decls[dep_id]).span,
152+
ident: decl_ident(&self.module.decls[dep_id])
153+
.expect("decl should have ident")
154+
.span,
150155
path: self.path[start_at..]
151156
.iter()
152157
.map(|curr_dep| {
153158
let curr_id = curr_dep.decl;
154159
let curr_decl = &self.module.decls[curr_id];
155160

156-
(decl_ident(curr_decl).span, curr_dep.usage)
161+
(
162+
decl_ident(curr_decl).expect("decl should have ident").span,
163+
curr_dep.usage,
164+
)
157165
})
158166
.collect(),
159167
})
@@ -182,13 +190,14 @@ impl<'a> DependencySolver<'a, '_> {
182190
}
183191
}
184192

185-
const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
193+
const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> Option<ast::Ident<'a>> {
186194
match decl.kind {
187-
ast::GlobalDeclKind::Fn(ref f) => f.name,
188-
ast::GlobalDeclKind::Var(ref v) => v.name,
189-
ast::GlobalDeclKind::Const(ref c) => c.name,
190-
ast::GlobalDeclKind::Override(ref o) => o.name,
191-
ast::GlobalDeclKind::Struct(ref s) => s.name,
192-
ast::GlobalDeclKind::Type(ref t) => t.name,
195+
ast::GlobalDeclKind::Fn(ref f) => Some(f.name),
196+
ast::GlobalDeclKind::Var(ref v) => Some(v.name),
197+
ast::GlobalDeclKind::Const(ref c) => Some(c.name),
198+
ast::GlobalDeclKind::Override(ref o) => Some(o.name),
199+
ast::GlobalDeclKind::Struct(ref s) => Some(s.name),
200+
ast::GlobalDeclKind::Type(ref t) => Some(t.name),
201+
ast::GlobalDeclKind::ConstAssert(_) => None,
193202
}
194203
}

naga/src/front/wgsl/lower/mod.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
12041204
ctx.globals
12051205
.insert(alias.name.name, LoweredGlobalDecl::Type(ty));
12061206
}
1207+
ast::GlobalDeclKind::ConstAssert(condition) => {
1208+
let condition = self.expression(condition, &mut ctx.as_const())?;
1209+
1210+
let span = ctx.module.global_expressions.get_span(condition);
1211+
match ctx
1212+
.module
1213+
.to_ctx()
1214+
.eval_expr_to_bool_from(condition, &ctx.module.global_expressions)
1215+
{
1216+
Some(true) => Ok(()),
1217+
Some(false) => Err(Error::ConstAssertFailed(span)),
1218+
_ => Err(Error::NotBool(span)),
1219+
}?;
1220+
}
12071221
}
12081222
}
12091223

@@ -1742,6 +1756,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
17421756
value,
17431757
}
17441758
}
1759+
ast::StatementKind::ConstAssert(condition) => {
1760+
let mut emitter = Emitter::default();
1761+
emitter.start(&ctx.function.expressions);
1762+
1763+
let condition =
1764+
self.expression(condition, &mut ctx.as_const(block, &mut emitter))?;
1765+
1766+
let span = ctx.function.expressions.get_span(condition);
1767+
match ctx
1768+
.module
1769+
.to_ctx()
1770+
.eval_expr_to_bool_from(condition, &ctx.function.expressions)
1771+
{
1772+
Some(true) => Ok(()),
1773+
Some(false) => Err(Error::ConstAssertFailed(span)),
1774+
_ => Err(Error::NotBool(span)),
1775+
}?;
1776+
1777+
block.extend(emitter.finish(&ctx.function.expressions));
1778+
1779+
return Ok(());
1780+
}
17451781
ast::StatementKind::Ignore(expr) => {
17461782
let mut emitter = Emitter::default();
17471783
emitter.start(&ctx.function.expressions);

naga/src/front/wgsl/parse/ast.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ pub enum GlobalDeclKind<'a> {
8585
Override(Override<'a>),
8686
Struct(Struct<'a>),
8787
Type(TypeAlias<'a>),
88+
ConstAssert(Handle<Expression<'a>>),
8889
}
8990

9091
#[derive(Debug)]
@@ -284,6 +285,7 @@ pub enum StatementKind<'a> {
284285
Increment(Handle<Expression<'a>>),
285286
Decrement(Handle<Expression<'a>>),
286287
Ignore(Handle<Expression<'a>>),
288+
ConstAssert(Handle<Expression<'a>>),
287289
}
288290

289291
#[derive(Debug)]

naga/src/front/wgsl/parse/mod.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2016,6 +2016,20 @@ impl Parser {
20162016
lexer.expect(Token::Separator(';'))?;
20172017
ast::StatementKind::Kill
20182018
}
2019+
// https://www.w3.org/TR/WGSL/#const-assert-statement
2020+
"const_assert" => {
2021+
let _ = lexer.next();
2022+
// parentheses are optional
2023+
let paren = lexer.skip(Token::Paren('('));
2024+
2025+
let condition = self.general_expression(lexer, ctx)?;
2026+
2027+
if paren {
2028+
lexer.expect(Token::Paren(')'))?;
2029+
}
2030+
lexer.expect(Token::Separator(';'))?;
2031+
ast::StatementKind::ConstAssert(condition)
2032+
}
20192033
// assignment or a function call
20202034
_ => {
20212035
self.function_call_or_assignment_statement(lexer, ctx, block)?;
@@ -2419,6 +2433,18 @@ impl Parser {
24192433
..function
24202434
}))
24212435
}
2436+
(Token::Word("const_assert"), _) => {
2437+
// parentheses are optional
2438+
let paren = lexer.skip(Token::Paren('('));
2439+
2440+
let condition = self.general_expression(lexer, &mut ctx)?;
2441+
2442+
if paren {
2443+
lexer.expect(Token::Paren(')'))?;
2444+
}
2445+
lexer.expect(Token::Separator(';'))?;
2446+
Some(ast::GlobalDeclKind::ConstAssert(condition))
2447+
}
24222448
(Token::End, _) => return Ok(()),
24232449
other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)),
24242450
};

naga/src/proc/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,19 @@ impl GlobalCtx<'_> {
674674
}
675675
}
676676

677+
/// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
678+
#[allow(dead_code)]
679+
pub(super) fn eval_expr_to_bool_from(
680+
&self,
681+
handle: crate::Handle<crate::Expression>,
682+
arena: &crate::Arena<crate::Expression>,
683+
) -> Option<bool> {
684+
match self.eval_expr_to_literal_from(handle, arena) {
685+
Some(crate::Literal::Bool(value)) => Some(value),
686+
_ => None,
687+
}
688+
}
689+
677690
#[allow(dead_code)]
678691
pub(crate) fn eval_expr_to_literal(
679692
&self,

naga/tests/in/const_assert.wgsl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Sourced from https://www.w3.org/TR/WGSL/#const-assert-statement
2+
const x = 1;
3+
const y = 2;
4+
const_assert x < y; // valid at module-scope.
5+
const_assert(y != 0); // parentheses are optional.
6+
7+
fn foo() {
8+
const z = x + y - 2;
9+
const_assert z > 0; // valid in functions.
10+
const_assert(z > 0);
11+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
(
2+
types: [
3+
(
4+
name: None,
5+
inner: Scalar((
6+
kind: Sint,
7+
width: 4,
8+
)),
9+
),
10+
],
11+
special_types: (
12+
ray_desc: None,
13+
ray_intersection: None,
14+
predeclared_types: {},
15+
),
16+
constants: [
17+
(
18+
name: Some("x"),
19+
ty: 0,
20+
init: 0,
21+
),
22+
(
23+
name: Some("y"),
24+
ty: 0,
25+
init: 1,
26+
),
27+
],
28+
overrides: [],
29+
global_variables: [],
30+
global_expressions: [
31+
Literal(I32(1)),
32+
Literal(I32(2)),
33+
],
34+
functions: [
35+
(
36+
name: Some("foo"),
37+
arguments: [],
38+
result: None,
39+
local_variables: [],
40+
expressions: [
41+
Literal(I32(1)),
42+
],
43+
named_expressions: {
44+
0: "z",
45+
},
46+
body: [
47+
Return(
48+
value: None,
49+
),
50+
],
51+
),
52+
],
53+
entry_points: [],
54+
)

naga/tests/out/ir/const_assert.ron

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
(
2+
types: [
3+
(
4+
name: None,
5+
inner: Scalar((
6+
kind: Sint,
7+
width: 4,
8+
)),
9+
),
10+
],
11+
special_types: (
12+
ray_desc: None,
13+
ray_intersection: None,
14+
predeclared_types: {},
15+
),
16+
constants: [
17+
(
18+
name: Some("x"),
19+
ty: 0,
20+
init: 0,
21+
),
22+
(
23+
name: Some("y"),
24+
ty: 0,
25+
init: 1,
26+
),
27+
],
28+
overrides: [],
29+
global_variables: [],
30+
global_expressions: [
31+
Literal(I32(1)),
32+
Literal(I32(2)),
33+
],
34+
functions: [
35+
(
36+
name: Some("foo"),
37+
arguments: [],
38+
result: None,
39+
local_variables: [],
40+
expressions: [
41+
Literal(I32(1)),
42+
],
43+
named_expressions: {
44+
0: "z",
45+
},
46+
body: [
47+
Return(
48+
value: None,
49+
),
50+
],
51+
),
52+
],
53+
entry_points: [],
54+
)

naga/tests/out/wgsl/const_assert.wgsl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
const x: i32 = 1i;
2+
const y: i32 = 2i;
3+
4+
fn foo() {
5+
return;
6+
}
7+

0 commit comments

Comments
 (0)