Skip to content

Commit c07fab2

Browse files
jamienicoljimblandy
authored andcommitted
[naga wgsl-in] Allow abstract literals to be used as return values
When lowering a return statement, call expression_for_abstract() rather than expression() to avoid concretizing the return value. Then, if the function has a return type, call try_automatic_conversions() to attempt to convert our return value to the correct type. This has the unfortunate side effect that some errors that would have been caught by the validator are instead encountered as conversion errors by the parser. This may result in a slightly less descriptive error message in some cases. (See the change to the invalid_functions() test, for example.)
1 parent 005bde9 commit c07fab2

11 files changed

+292
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924).
7373
#### Naga
7474

7575
- Fix some instances of functions which have a return type but don't return a value being incorrectly validated. By @jamienicol in [#7013](https://github.com/gfx-rs/wgpu/pull/7013).
76+
- Allow abstract expressions to be used in WGSL function return statements. By @jamienicol in [#7035](https://github.com/gfx-rs/wgpu/pull/7035).
7677

7778
#### General
7879

naga/src/front/wgsl/lower/mod.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,13 +1672,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
16721672
}
16731673
ast::StatementKind::Break => crate::Statement::Break,
16741674
ast::StatementKind::Continue => crate::Statement::Continue,
1675-
ast::StatementKind::Return { value } => {
1675+
ast::StatementKind::Return { value: ast_value } => {
16761676
let mut emitter = Emitter::default();
16771677
emitter.start(&ctx.function.expressions);
16781678

1679-
let value = value
1680-
.map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter)))
1681-
.transpose()?;
1679+
let value;
1680+
if let Some(ast_expr) = ast_value {
1681+
let result_ty = ctx.function.result.as_ref().map(|r| r.ty);
1682+
let mut ectx = ctx.as_expression(block, &mut emitter);
1683+
let expr = self.expression_for_abstract(ast_expr, &mut ectx)?;
1684+
1685+
if let Some(result_ty) = result_ty {
1686+
let mut ectx = ctx.as_expression(block, &mut emitter);
1687+
let resolution = crate::proc::TypeResolution::Handle(result_ty);
1688+
let converted =
1689+
ectx.try_automatic_conversions(expr, &resolution, Span::default())?;
1690+
value = Some(converted);
1691+
} else {
1692+
value = Some(expr);
1693+
}
1694+
} else {
1695+
value = None;
1696+
}
16821697
block.extend(emitter.finish(&ctx.function.expressions));
16831698

16841699
crate::Statement::Return { value }
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
@compute @workgroup_size(1)
2+
fn main() {}
3+
4+
fn return_i32_ai() -> i32 {
5+
return 1;
6+
}
7+
8+
fn return_u32_ai() -> u32 {
9+
return 1;
10+
}
11+
12+
fn return_f32_ai() -> f32 {
13+
return 1;
14+
}
15+
16+
fn return_f32_af() -> f32 {
17+
return 1.0;
18+
}
19+
20+
fn return_vec2f32_ai() -> vec2<f32> {
21+
return vec2(1);
22+
}
23+
24+
fn return_arrf32_ai() -> array<f32, 4> {
25+
return array(1, 1, 1, 1);
26+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#version 310 es
2+
3+
precision highp float;
4+
precision highp int;
5+
6+
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
7+
8+
9+
int return_i32_ai() {
10+
return 1;
11+
}
12+
13+
uint return_u32_ai() {
14+
return 1u;
15+
}
16+
17+
float return_f32_ai() {
18+
return 1.0;
19+
}
20+
21+
float return_f32_af() {
22+
return 1.0;
23+
}
24+
25+
vec2 return_vec2f32_ai() {
26+
return vec2(1.0);
27+
}
28+
29+
float[4] return_arrf32_ai() {
30+
return float[4](1.0, 1.0, 1.0, 1.0);
31+
}
32+
33+
void main() {
34+
return;
35+
}
36+
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
int return_i32_ai()
2+
{
3+
return 1;
4+
}
5+
6+
uint return_u32_ai()
7+
{
8+
return 1u;
9+
}
10+
11+
float return_f32_ai()
12+
{
13+
return 1.0;
14+
}
15+
16+
float return_f32_af()
17+
{
18+
return 1.0;
19+
}
20+
21+
float2 return_vec2f32_ai()
22+
{
23+
return (1.0).xx;
24+
}
25+
26+
typedef float ret_Constructarray4_float_[4];
27+
ret_Constructarray4_float_ Constructarray4_float_(float arg0, float arg1, float arg2, float arg3) {
28+
float ret[4] = { arg0, arg1, arg2, arg3 };
29+
return ret;
30+
}
31+
32+
typedef float ret_return_arrf32_ai[4];
33+
ret_return_arrf32_ai return_arrf32_ai()
34+
{
35+
return Constructarray4_float_(1.0, 1.0, 1.0, 1.0);
36+
}
37+
38+
[numthreads(1, 1, 1)]
39+
void main()
40+
{
41+
return;
42+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
(
2+
vertex:[
3+
],
4+
fragment:[
5+
],
6+
compute:[
7+
(
8+
entry_point:"main",
9+
target_profile:"cs_5_1",
10+
),
11+
],
12+
)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// language: metal1.0
2+
#include <metal_stdlib>
3+
#include <simd/simd.h>
4+
5+
using metal::uint;
6+
7+
struct type_4 {
8+
float inner[4];
9+
};
10+
11+
int return_i32_ai(
12+
) {
13+
return 1;
14+
}
15+
16+
uint return_u32_ai(
17+
) {
18+
return 1u;
19+
}
20+
21+
float return_f32_ai(
22+
) {
23+
return 1.0;
24+
}
25+
26+
float return_f32_af(
27+
) {
28+
return 1.0;
29+
}
30+
31+
metal::float2 return_vec2f32_ai(
32+
) {
33+
return metal::float2(1.0);
34+
}
35+
36+
type_4 return_arrf32_ai(
37+
) {
38+
return type_4 {1.0, 1.0, 1.0, 1.0};
39+
}
40+
41+
kernel void main_(
42+
) {
43+
return;
44+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
; SPIR-V
2+
; Version: 1.1
3+
; Generator: rspirv
4+
; Bound: 41
5+
OpCapability Shader
6+
%1 = OpExtInstImport "GLSL.std.450"
7+
OpMemoryModel Logical GLSL450
8+
OpEntryPoint GLCompute %38 "main"
9+
OpExecutionMode %38 LocalSize 1 1 1
10+
OpDecorate %7 ArrayStride 4
11+
%2 = OpTypeVoid
12+
%3 = OpTypeInt 32 1
13+
%4 = OpTypeInt 32 0
14+
%5 = OpTypeFloat 32
15+
%6 = OpTypeVector %5 2
16+
%8 = OpConstant %4 4
17+
%7 = OpTypeArray %5 %8
18+
%11 = OpTypeFunction %3
19+
%12 = OpConstant %3 1
20+
%16 = OpTypeFunction %4
21+
%17 = OpConstant %4 1
22+
%21 = OpTypeFunction %5
23+
%22 = OpConstant %5 1.0
24+
%29 = OpTypeFunction %6
25+
%30 = OpConstantComposite %6 %22 %22
26+
%34 = OpTypeFunction %7
27+
%35 = OpConstantComposite %7 %22 %22 %22 %22
28+
%39 = OpTypeFunction %2
29+
%10 = OpFunction %3 None %11
30+
%9 = OpLabel
31+
OpBranch %13
32+
%13 = OpLabel
33+
OpReturnValue %12
34+
OpFunctionEnd
35+
%15 = OpFunction %4 None %16
36+
%14 = OpLabel
37+
OpBranch %18
38+
%18 = OpLabel
39+
OpReturnValue %17
40+
OpFunctionEnd
41+
%20 = OpFunction %5 None %21
42+
%19 = OpLabel
43+
OpBranch %23
44+
%23 = OpLabel
45+
OpReturnValue %22
46+
OpFunctionEnd
47+
%25 = OpFunction %5 None %21
48+
%24 = OpLabel
49+
OpBranch %26
50+
%26 = OpLabel
51+
OpReturnValue %22
52+
OpFunctionEnd
53+
%28 = OpFunction %6 None %29
54+
%27 = OpLabel
55+
OpBranch %31
56+
%31 = OpLabel
57+
OpReturnValue %30
58+
OpFunctionEnd
59+
%33 = OpFunction %7 None %34
60+
%32 = OpLabel
61+
OpBranch %36
62+
%36 = OpLabel
63+
OpReturnValue %35
64+
OpFunctionEnd
65+
%38 = OpFunction %2 None %39
66+
%37 = OpLabel
67+
OpBranch %40
68+
%40 = OpLabel
69+
OpReturn
70+
OpFunctionEnd
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
fn return_i32_ai() -> i32 {
2+
return 1i;
3+
}
4+
5+
fn return_u32_ai() -> u32 {
6+
return 1u;
7+
}
8+
9+
fn return_f32_ai() -> f32 {
10+
return 1f;
11+
}
12+
13+
fn return_f32_af() -> f32 {
14+
return 1f;
15+
}
16+
17+
fn return_vec2f32_ai() -> vec2<f32> {
18+
return vec2(1f);
19+
}
20+
21+
fn return_arrf32_ai() -> array<f32, 4> {
22+
return array<f32, 4>(1f, 1f, 1f, 1f);
23+
}
24+
25+
@compute @workgroup_size(1, 1, 1)
26+
fn main() {
27+
return;
28+
}

naga/tests/snapshots.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,10 @@ fn convert_wgsl() {
921921
"abstract-types-operators",
922922
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL,
923923
),
924+
(
925+
"abstract-types-return",
926+
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
927+
),
924928
(
925929
"int64",
926930
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,

0 commit comments

Comments
 (0)