From e7d0b6d9bf1ca05942e167e93fcf1eb6348ff144 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Mon, 7 Jul 2025 08:08:04 +0300 Subject: [PATCH 1/6] compiletests: future-proof `issue-46` against MIR optimizations. --- tests/compiletests/ui/lang/issue-46.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/compiletests/ui/lang/issue-46.rs b/tests/compiletests/ui/lang/issue-46.rs index 99a0835a95..fc64dd01d7 100644 --- a/tests/compiletests/ui/lang/issue-46.rs +++ b/tests/compiletests/ui/lang/issue-46.rs @@ -2,7 +2,7 @@ use spirv_std::spirv; -#[derive(Default)] +#[derive(Copy, Clone, Default)] struct Foo { bar: bool, baz: [[u32; 2]; 1], @@ -13,3 +13,20 @@ pub fn main() { let x = [[1; 2]; 1]; let y = [Foo::default(); 1]; } + +// HACK(eddyb) future-proofing against `[expr; 1]` -> `[expr]` +// MIR optimization (https://github.com/rust-lang/rust/pull/135322). +fn force_repeat_one() -> ([[u32; 2]; ONE], [Foo; ONE]) { + ([[1; 2]; ONE], [Foo::default(); ONE]) +} + +#[spirv(fragment)] +pub fn main_future_proof( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut [[u32; 2]; 2], +) { + let (x, y) = force_repeat_one::<1>(); + + // NOTE(eddyb) further guard against optimizations by using `x` and `y`. + out[0] = x[0]; + out[1] = y[0].baz[0]; +} From c2ebd1aaa5ef6c0bf18ed68db8bcbfa8a8eedff7 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Wed, 7 May 2025 19:42:00 +0300 Subject: [PATCH 2/6] builder: always `pointercast` first before attempting to merge `OpAccessChain`s. 
--- .../src/builder/builder_methods.rs | 56 ++++++++++++------- .../ui/dis/panic_builtin_bounds_check.stderr | 50 +++++++++++------ 2 files changed, 70 insertions(+), 36 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 5521aa973c..cae01bd8e5 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -854,12 +854,34 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // --- End Recovery Path --- + // FIXME(eddyb) the comments below might not make much sense, because this + // used to be in the "fallback path" before being moved to before merging. + // + // Before emitting the AccessChain, explicitly cast the base pointer `ptr` to + // ensure its pointee type matches the input `ty`. This is required because the + // SPIR-V `AccessChain` instruction implicitly uses the size of the base + // pointer's pointee type when applying the *first* index operand (our + // `ptr_base_index`). If `ty` and `original_pointee_ty` mismatched and we + // reached this fallback, this cast ensures SPIR-V validity. + trace!("maybe_inbounds_gep fallback path calling pointercast"); + // Cast ptr to point to `ty`. + // HACK(eddyb) temporary workaround for untyped pointers upstream. + // FIXME(eddyb) replace with untyped memory SPIR-V + `qptr` or similar. + let ptr = self.pointercast(ptr, self.type_ptr_to(ty)); + // Get the ID of the (potentially newly casted) pointer. + let ptr_id = ptr.def(self); + // HACK(eddyb) updated pointee type of `ptr` post-`pointercast`. + let original_pointee_ty = ty; + // --- Attempt GEP Merging Path --- // Check if the base pointer `ptr` itself was the result of a previous // AccessChain instruction. Merging is only attempted if the input type `ty` // matches the pointer's actual underlying pointee type `original_pointee_ty`. // If they differ, merging could be invalid. 
+ // HACK(eddyb) always attempted now, because we `pointercast` first, which: + // - is noop when `ty == original_pointee_ty` pre-`pointercast` (old condition) + // - may generate (potentially mergeable) new `AccessChain`s in other cases let maybe_original_access_chain = if ty == original_pointee_ty { // Search the current function's instructions... // FIXME(eddyb) this could get ridiculously expensive, at the very least @@ -908,12 +930,21 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // 2. The *last* index of the original AccessChain is a constant. // 3. The *first* index (`ptr_base_index`) of the *current* GEP is a constant. // Merging usually involves adding these two constant indices. + // + // FIXME(eddyb) the above comment seems inaccurate, there is no reason + // why runtime indices couldn't be added together just like constants + // (and in fact this is needed nowadays for all array indexing). let can_merge = if let Some(&last_original_idx_id) = original_indices.last() { - // Check if both the last original index and the current base index are constant scalars. - self.builder - .lookup_const_scalar(last_original_idx_id.with_type(ptr_base_index.ty)) - .is_some() - && self.builder.lookup_const_scalar(ptr_base_index).is_some() + // HACK(eddyb) see the above comment, this bypasses the const + // check below, without tripping a clippy warning etc. + let always_merge = true; + always_merge || { + // Check if both the last original index and the current base index are constant scalars. + self.builder + .lookup_const_scalar(last_original_idx_id.with_type(ptr_base_index.ty)) + .is_some() + && self.builder.lookup_const_scalar(ptr_base_index).is_some() + } } else { // Original access chain had no indices to merge with. false @@ -966,21 +997,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // This path is taken if neither the Recovery nor the Merging path succeeded or applied. // It performs a more direct translation of the GEP request. 
- // HACK(eddyb): Workaround for potential upstream issues where pointers might lack precise type info. - // FIXME(eddyb): Ideally, this should use untyped memory features if available/necessary. - - // Before emitting the AccessChain, explicitly cast the base pointer `ptr` to - // ensure its pointee type matches the input `ty`. This is required because the - // SPIR-V `AccessChain` instruction implicitly uses the size of the base - // pointer's pointee type when applying the *first* index operand (our - // `ptr_base_index`). If `ty` and `original_pointee_ty` mismatched and we - // reached this fallback, this cast ensures SPIR-V validity. - trace!("maybe_inbounds_gep fallback path calling pointercast"); - // Cast ptr to point to `ty`. - let ptr = self.pointercast(ptr, self.type_ptr_to(ty)); - // Get the ID of the (potentially newly casted) pointer. - let ptr_id = ptr.def(self); - trace!( "emitting access chain via fallback path with pointer type: {}", self.debug_type(final_spirv_ptr_type) diff --git a/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr b/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr index d2066bab76..5a0a23a6fb 100644 --- a/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr +++ b/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr @@ -6,27 +6,45 @@ OpEntryPoint Fragment %2 "main" OpExecutionMode %2 OriginUpperLeft %3 = OpString "/n[Rust panicked at $DIR/panic_builtin_bounds_check.rs:25:5]/n index out of bounds: the len is %u but the index is %u/n in main()/n" %4 = OpString "$DIR/panic_builtin_bounds_check.rs" -%5 = OpTypeVoid -%6 = OpTypeFunction %5 -%7 = OpTypeBool +OpDecorate %5 ArrayStride 4 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 %8 = OpTypeInt 32 0 -%9 = OpConstant %8 5 -%10 = OpConstant %8 4 -%11 = OpUndef %8 -%2 = OpFunction %5 None %6 -%12 = OpLabel +%9 = OpConstant %8 4 +%5 = OpTypeArray %8 %9 +%10 = OpTypePointer Function %5 +%11 = OpConstant %8 0 +%12 = OpConstant %8 1 +%13 = OpConstant %8 2 +%14 = 
OpConstant %8 3 +%15 = OpTypeBool +%16 = OpConstant %8 5 +%17 = OpUndef %8 +%18 = OpTypePointer Function %8 +%2 = OpFunction %6 None %7 +%19 = OpLabel +OpLine %4 30 4 +%20 = OpVariable %10 Function +OpLine %4 30 23 +%21 = OpCompositeConstruct %5 %11 %12 %13 %14 OpLine %4 25 4 -%13 = OpULessThan %7 %9 %10 +OpStore %20 %21 +%22 = OpULessThan %15 %16 %9 OpNoLine -OpSelectionMerge %14 None -OpBranchConditional %13 %15 %16 -%15 = OpLabel -OpBranch %14 -%16 = OpLabel +OpSelectionMerge %23 None +OpBranchConditional %22 %24 %25 +%24 = OpLabel +OpBranch %23 +%25 = OpLabel OpLine %4 25 4 -%17 = OpExtInst %5 %1 1 %3 %11 %9 +%26 = OpExtInst %6 %1 1 %3 %17 %16 OpNoLine OpReturn -%14 = OpLabel +%23 = OpLabel +OpLine %4 25 4 +%27 = OpIAdd %8 %11 %16 +%28 = OpInBoundsAccessChain %18 %20 %27 +%29 = OpLoad %8 %28 +OpNoLine OpReturn OpFunctionEnd From dd767a6061e9c492124edd976e93b5b28657711b Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Mon, 7 Jul 2025 08:21:57 +0300 Subject: [PATCH 3/6] builder: backport `rust-lang/rust#134117` (`[0, i]` -> `[i]` for array GEPs). 
--- .../src/builder/builder_methods.rs | 77 +++++++++++++++++++ tests/compiletests/ui/lang/issue-46.stderr | 8 ++ .../storage_class/typed-buffer-simple.stderr | 8 ++ 3 files changed, 93 insertions(+) create mode 100644 tests/compiletests/ui/lang/issue-46.stderr create mode 100644 tests/compiletests/ui/storage_class/typed-buffer-simple.stderr diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index cae01bd8e5..e62a48390b 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -655,6 +655,68 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // https://github.com/gpuweb/gpuweb/issues/33 let (&ptr_base_index, indices) = combined_indices.split_first().unwrap(); + // HACK(eddyb) this effectively removes any real support for GEPs with + // any `indices` (beyond `ptr_base_index`), which should now be the case + // across `rustc_codegen_ssa` (see also comment inside `inbounds_gep`). + // FIXME(eddyb) are the warning + fallback path even work keeping? + if !indices.is_empty() { + // HACK(eddyb) Cargo silences warnings in dependencies. + let force_warn = |span, msg| -> rustc_errors::Diag<'_, ()> { + rustc_errors::Diag::new( + self.tcx.dcx(), + rustc_errors::Level::ForceWarning(None), + msg, + ) + .with_span(span) + }; + force_warn( + self.span(), + format!( + "[RUST-GPU BUG] `inbounds_gep` or `gep` called with \ + {} combined indices (expected only 1)", + combined_indices.len(), + ), + ) + .emit(); + + let indexed_base_ptr = self.maybe_inbounds_gep(ty, ptr, &[ptr_base_index], is_inbounds); + let indexed_base_ptr_id = indexed_base_ptr.def(self); + assert_ty_eq!(self, indexed_base_ptr.ty, self.type_ptr_to(ty)); + + let mut final_pointee = ty; + for &index in indices { + final_pointee = match self.lookup_type(final_pointee) { + SpirvType::Adt { field_types, .. 
} => { + field_types[self + .builder + .lookup_const_scalar(index) + .expect("non-const struct index for GEP") + as usize] + } + SpirvType::Array { element, .. } + | SpirvType::RuntimeArray { element } + | SpirvType::Vector { element, .. } + | SpirvType::Matrix { element, .. } => element, + + _ => self.fatal(format!( + "GEP not implemented for indexing into type {}", + self.debug_type(final_pointee) + )), + }; + } + let final_spirv_ptr_type = self.type_ptr_to(final_pointee); + + let indices_ids: Vec<_> = indices.iter().map(|index| index.def(self)).collect(); + + return self.emit_access_chain( + final_spirv_ptr_type, + indexed_base_ptr_id, + None, + indices_ids, + is_inbounds, + ); + } + // Determine if this GEP operation is effectively byte-level addressing. // This check is based on the *provided* input type `ty`. If `ty` is i8 or u8, // it suggests the caller intends to perform byte-offset calculations, @@ -1972,6 +2034,21 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { ptr: Self::Value, indices: &[Self::Value], ) -> Self::Value { + // HACK(eddyb) effectively a backport of this `gep [0, i]` -> `gep [i]` + // PR: https://github.com/rust-lang/rust/pull/134117 to even earlier + // nightlies - and that PR happens to remove the last GEP that can be + // emitted with any "structured" (struct/array) indices, beyond the + // "first index" (which acts as `<*T>::offset` aka "pointer arithmetic"). + if let &[ptr_base_index, structured_index] = indices { + if self.builder.lookup_const_scalar(ptr_base_index) == Some(0) { + if let SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element, .. 
} = + self.lookup_type(ty) + { + return self.maybe_inbounds_gep(element, ptr, &[structured_index], true); + } + } + } + self.maybe_inbounds_gep(ty, ptr, indices, true) } diff --git a/tests/compiletests/ui/lang/issue-46.stderr b/tests/compiletests/ui/lang/issue-46.stderr new file mode 100644 index 0000000000..baaa29572b --- /dev/null +++ b/tests/compiletests/ui/lang/issue-46.stderr @@ -0,0 +1,8 @@ +error: error:0:0 - Result type (OpTypeArray) does not match the type that results from indexing into the composite (OpTypeArray). + %67 = OpCompositeExtract %_arr_uint_uint_2 %49 0 0 + | + = note: spirv-val failed + = note: module `$TEST_BUILD_DIR/lang/issue-46.spv1.3` + +error: aborting due to 1 previous error + diff --git a/tests/compiletests/ui/storage_class/typed-buffer-simple.stderr b/tests/compiletests/ui/storage_class/typed-buffer-simple.stderr new file mode 100644 index 0000000000..f46577983d --- /dev/null +++ b/tests/compiletests/ui/storage_class/typed-buffer-simple.stderr @@ -0,0 +1,8 @@ +error: error:0:0 - OpInBoundsAccessChain result type (OpTypeInt) does not match the type that results from indexing into the base (OpTypeArray). + %33 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %28 %uint_0 + | + = note: spirv-val failed + = note: module `$TEST_BUILD_DIR/storage_class/typed-buffer-simple.spv1.3` + +error: aborting due to 1 previous error + From 72a38c84c92f5a41594b4d33e42366df70511502 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Mon, 7 Jul 2025 08:58:45 +0300 Subject: [PATCH 4/6] builder: special-case `0`-byte-offset GEPs as pointer casts. 
--- crates/rustc_codegen_spirv/src/builder/builder_methods.rs | 8 ++++++++ tests/compiletests/ui/lang/issue-46.stderr | 8 -------- .../ui/storage_class/typed-buffer-simple.stderr | 8 -------- 3 files changed, 8 insertions(+), 16 deletions(-) delete mode 100644 tests/compiletests/ui/lang/issue-46.stderr delete mode 100644 tests/compiletests/ui/storage_class/typed-buffer-simple.stderr diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index e62a48390b..40dd9a502f 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -817,6 +817,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // If we successfully calculated a constant byte offset for the first index... if let Some(const_ptr_offset_bytes) = const_ptr_offset_bytes { + // HACK(eddyb) an offset of `0` is always a noop, and `pointercast` + // gets to use `SpirvValueKind::LogicalPtrCast`, which can later + // be "undone" (by `strip_ptrcasts`), allowing more flexibility + // downstream (instead of overeagerly "shrinking" the pointee). + if const_ptr_offset_bytes == 0 { + return self.pointercast(ptr, final_spirv_ptr_type); + } + // Try to reconstruct a more "structured" access chain based on the *original* // pointee type of the pointer (`original_pointee_ty`) and the calculated byte offset. // This is useful if the input `ty` was generic (like u8) but the pointer actually diff --git a/tests/compiletests/ui/lang/issue-46.stderr b/tests/compiletests/ui/lang/issue-46.stderr deleted file mode 100644 index baaa29572b..0000000000 --- a/tests/compiletests/ui/lang/issue-46.stderr +++ /dev/null @@ -1,8 +0,0 @@ -error: error:0:0 - Result type (OpTypeArray) does not match the type that results from indexing into the composite (OpTypeArray). 
- %67 = OpCompositeExtract %_arr_uint_uint_2 %49 0 0 - | - = note: spirv-val failed - = note: module `$TEST_BUILD_DIR/lang/issue-46.spv1.3` - -error: aborting due to 1 previous error - diff --git a/tests/compiletests/ui/storage_class/typed-buffer-simple.stderr b/tests/compiletests/ui/storage_class/typed-buffer-simple.stderr deleted file mode 100644 index f46577983d..0000000000 --- a/tests/compiletests/ui/storage_class/typed-buffer-simple.stderr +++ /dev/null @@ -1,8 +0,0 @@ -error: error:0:0 - OpInBoundsAccessChain result type (OpTypeInt) does not match the type that results from indexing into the base (OpTypeArray). - %33 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %28 %uint_0 - | - = note: spirv-val failed - = note: module `$TEST_BUILD_DIR/storage_class/typed-buffer-simple.spv1.3` - -error: aborting due to 1 previous error - From b1f086105dbe5719391f5b563529f7ca4c640879 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Mon, 7 Jul 2025 09:11:40 +0300 Subject: [PATCH 5/6] builder: move GEP handling into a simpler `ptr_offset_strided` method. --- .../src/builder/builder_methods.rs | 473 +++++------------- 1 file changed, 132 insertions(+), 341 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 40dd9a502f..a19b462440 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -620,13 +620,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { && leaf_size_range.contains(&ty_size) && leaf_ty.map_or(true, |leaf_ty| leaf_ty == ty) { - trace!("returning type: {:?}", self.debug_type(ty)); - trace!("returning indices with len: {:?}", indices.len()); + trace!("successful recovery leaf type: {:?}", self.debug_type(ty)); + trace!("successful recovery indices: {:?}", indices); return Some((indices, ty)); } } } + // NOTE(eddyb) see `ptr_offset_strided`, which this now forwards to. 
#[instrument(level = "trace", skip(self), fields(ty = ?self.debug_type(ty), ptr, combined_indices = ?combined_indices.iter().map(|x| (self.debug_type(x.ty), x.kind)).collect::>(), is_inbounds))] fn maybe_inbounds_gep( &mut self, @@ -655,6 +656,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // https://github.com/gpuweb/gpuweb/issues/33 let (&ptr_base_index, indices) = combined_indices.split_first().unwrap(); + let indexed_base_ptr = self.ptr_offset_strided(ptr, ty, ptr_base_index, is_inbounds); + // HACK(eddyb) this effectively removes any real support for GEPs with // any `indices` (beyond `ptr_base_index`), which should now be the case // across `rustc_codegen_ssa` (see also comment inside `inbounds_gep`). @@ -679,7 +682,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ) .emit(); - let indexed_base_ptr = self.maybe_inbounds_gep(ty, ptr, &[ptr_base_index], is_inbounds); let indexed_base_ptr_id = indexed_base_ptr.def(self); assert_ty_eq!(self, indexed_base_ptr.ty, self.type_ptr_to(ty)); @@ -717,368 +719,157 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ); } - // Determine if this GEP operation is effectively byte-level addressing. - // This check is based on the *provided* input type `ty`. If `ty` is i8 or u8, - // it suggests the caller intends to perform byte-offset calculations, - // which might allow for more flexible type recovery later. - let is_byte_gep = matches!(self.lookup_type(ty), SpirvType::Integer(8, _)); - trace!("Is byte GEP (based on input type): {}", is_byte_gep); - - // --- Calculate the final pointee type based on the GEP operation --- - - // This loop does the type traversal according to the `indices` (excluding the - // base offset index). It starts with the initial element type `ty` and - // iteratively applies each index to determine the type of the element being - // accessed at each step. 
The result is the type that the *final* pointer, - // generated by the SPIR-V `AccessChain`` instruction, *must* point to according - // to the SPIR-V specification and the provided `indices`. - let mut calculated_pointee_type = ty; - for index_val in indices { - // Lookup the current aggregate type we are indexing into. - calculated_pointee_type = match self.lookup_type(calculated_pointee_type) { - // If it's a struct (ADT), the index must be a constant. Use it to get - // the field type. - SpirvType::Adt { field_types, .. } => { - let const_index = self - .builder - .lookup_const_scalar(*index_val) - .expect("Non-constant struct index for GEP") - as usize; - // Get the type of the specific field. - field_types[const_index] - } - // If it's an array, vector, or matrix, indexing yields the element type. - SpirvType::Array { element, .. } - | SpirvType::RuntimeArray { element } - | SpirvType::Vector { element, .. } - | SpirvType::Matrix { element, .. } => element, - // Special case: If we started with a byte GEP (`is_byte_gep` is true) and - // we are currently indexing into a byte type, the result is still a byte type. - // This prevents errors if `indices` contains non-zero values when `ty` is u8/i8. - SpirvType::Integer(8, signedness) if is_byte_gep => { - // Define the resulting byte type as it might not exist yet). - SpirvType::Integer(8, signedness).def(self.span(), self) - } - // Any other type cannot be indexed into via GEP. - _ => self.fatal(format!( - "GEP not implemented for indexing into type {}", - self.debug_type(calculated_pointee_type) - )), - }; - } - // Construct the SPIR-V pointer type that points to the final calculated pointee - // type. This is the *required* result type for the SPIR-V `AccessChain` - // instruction. 
- let final_spirv_ptr_type = self.type_ptr_to(calculated_pointee_type); - trace!( - "Calculated final SPIR-V pointee type: {}", - self.debug_type(calculated_pointee_type) - ); - trace!( - "Calculated final SPIR-V ptr type: {}", - self.debug_type(final_spirv_ptr_type) - ); - - // Ensure all the `indices` (excluding the base offset index) are defined in the - // SPIR-V module and get their corresponding SPIR-V IDs. These IDs will be used - // as operands in the AccessChain instruction. - let gep_indices_ids: Vec<_> = indices.iter().map(|index| index.def(self)).collect(); - - // --- Prepare the base pointer --- - - // Remove any potentially redundant pointer casts applied to the input `ptr`. - // GEP operations should ideally work on the "underlying" pointer. - let ptr = ptr.strip_ptrcasts(); - // Get the SPIR-V ID for the (potentially stripped) base pointer. - let ptr_id = ptr.def(self); - // Determine the actual pointee type of the base pointer `ptr` *after* stripping casts. - // This might differ from the input `ty` if `ty` was less specific (e.g., u8). - let original_pointee_ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { pointee } => pointee, - other => self.fatal(format!("gep called on non-pointer type: {other:?}")), - }; + indexed_base_ptr + } - // --- Recovery Path --- + /// Array-indexing-like pointer arithmetic, i.e. `(ptr: *T).offset(index)` + /// (or `wrapping_offset` instead of `offset`, for `is_inbounds = false`), + /// where `T` is given by `stride_elem_ty` (named so for extra clarity). + /// + /// This can produce legal SPIR-V by using 3 strategies: + /// 1. `pointercast` for a constant offset of `0` + /// - itself can succeed via `recover_access_chain_from_offset` + /// - even if a specific cast is unsupported, legal SPIR-V can still be + /// obtained, if all downstream uses rely on e.g. `strip_ptrcasts` + /// - also used before the merge strategy (3.), to improve its chances + /// 2. 
`recover_access_chain_from_offset` for constant offsets + /// (e.g. from `ptradd`/`inbounds_ptradd` used to access `struct` fields) + /// 3. merging onto an array `OpAccessChain` with the same `stride_elem_ty` + /// (possibly `&array[0]` from `pointercast` doing `*[T; N]` -> `*T`) + #[instrument(level = "trace", skip(self), fields(ptr, stride_elem_ty = ?self.debug_type(stride_elem_ty), index, is_inbounds))] + fn ptr_offset_strided( + &mut self, + ptr: SpirvValue, + stride_elem_ty: Word, + index: SpirvValue, + is_inbounds: bool, + ) -> SpirvValue { + // Precompute a constant `index * stride` (i.e. effective pointer offset) + // if possible, as both strategies 1 and 2 rely on knowing this value. + let const_offset = self.builder.lookup_const_scalar(index).and_then(|idx| { + let idx_u64 = u64::try_from(idx).ok()?; + let stride = self.lookup_type(stride_elem_ty).sizeof(self)?; + Some(idx_u64 * stride) + }); - // Try to calculate the byte offset implied by the *first* index - // (`ptr_base_index`) if it's a compile-time constant. This uses the size of the - // *input type* `ty`. - let const_ptr_offset_bytes = self - .builder - .lookup_const_scalar(ptr_base_index) // Check if ptr_base_index is constant scalar - .and_then(|idx| { - let idx_u64 = u64::try_from(idx).ok()?; - // Get the size of the input type `ty` - self.lookup_type(ty) - .sizeof(self) - // Calculate offset in bytes - .map(|size| idx_u64.saturating_mul(size.bytes())) - }); - - // If we successfully calculated a constant byte offset for the first index... - if let Some(const_ptr_offset_bytes) = const_ptr_offset_bytes { - // HACK(eddyb) an offset of `0` is always a noop, and `pointercast` - // gets to use `SpirvValueKind::LogicalPtrCast`, which can later - // be "undone" (by `strip_ptrcasts`), allowing more flexibility - // downstream (instead of overeagerly "shrinking" the pointee). 
- if const_ptr_offset_bytes == 0 { - return self.pointercast(ptr, final_spirv_ptr_type); - } + // Strategy 1: an offset of `0` is always a noop, and `pointercast` + // gets to use `SpirvValueKind::LogicalPtrCast`, which can later + // be "undone" (by `strip_ptrcasts`), allowing more flexibility + // downstream (instead of overeagerly "shrinking" the pointee). + if const_offset == Some(Size::ZERO) { + trace!("ptr_offset_strided: strategy 1 picked: offset 0 => pointer cast"); + + // FIXME(eddyb) could this just `return ptr;`? what even breaks? + return self.pointercast(ptr, self.type_ptr_to(stride_elem_ty)); + } + + // Strategy 2: try recovering an `OpAccessChain` from a constant offset. + if let Some(const_offset) = const_offset { + // Remove any (redundant) pointer casts applied to the input `ptr`, + // to obtain the "most original" pointer (which ideally will be e.g. + // a whole `OpVariable`, or the result of a previous `OpAccessChain`). + let original_ptr = ptr.strip_ptrcasts(); + let original_pointee_ty = match self.lookup_type(original_ptr.ty) { + SpirvType::Pointer { pointee } => pointee, + other => self.fatal(format!("pointer arithmetic on non-pointer type {other:?}")), + }; - // Try to reconstruct a more "structured" access chain based on the *original* - // pointee type of the pointer (`original_pointee_ty`) and the calculated byte offset. - // This is useful if the input `ty` was generic (like u8) but the pointer actually - // points to a structured type (like a struct). `recover_access_chain_from_offset` - // attempts to find a sequence of constant indices (`base_indices`) into - // `original_pointee_ty` that matches the `const_ptr_offset_bytes`. 
- if let Some((base_indices, base_pointee_ty)) = self.recover_access_chain_from_offset( - // Start from the pointer's actual underlying type + if let Some((const_indices, leaf_pointee_ty)) = self.recover_access_chain_from_offset( original_pointee_ty, - // The target byte offset - Size::from_bytes(const_ptr_offset_bytes), - // Allowed range (not strictly needed here?) + const_offset, Some(Size::ZERO)..=None, - // Don't require alignment None, ) { - // Recovery successful! Found a structured path (`base_indices`) to the target offset. trace!( - "`recover_access_chain_from_offset` returned Some with base_pointee_ty: {}", - self.debug_type(base_pointee_ty) + "ptr_offset_strided: strategy 2 picked: offset {const_offset:?} \ + => access chain w/ {const_indices:?}" ); - // Determine the result type for the `AccessChain` instruction we might - // emit. By default, use the `final_spirv_ptr_type` strictly calculated - // earlier from `ty` and `indices`. - // - // If this is a byte GEP *and* the recovered type happens to be a byte - // type, we can use the pointer type derived from the *recovered* type - // (`base_pointee_ty`). This helps preserve type information when - // recovery works for byte addressing. - let result_wrapper_type = if !is_byte_gep - || matches!(self.lookup_type(base_pointee_ty), SpirvType::Integer(8, _)) - { - trace!( - "Using strictly calculated type for wrapper: {}", - // Use type based on input `ty` + `indices` - self.debug_type(calculated_pointee_type) - ); - final_spirv_ptr_type - } else { - trace!( - "Byte GEP allowing recovered type for wrapper: {}", - // Use type based on recovery result - self.debug_type(base_pointee_ty) - ); - self.type_ptr_to(base_pointee_ty) - }; - - // Check if we can directly use the recovered path combined with the - // remaining indices. This is possible if: - // 1. The input type `ty` matches the type found by recovery - // (`base_pointee_ty`). 
This means the recovery didn't fundamentally - // change the type interpretation needed for the *next* steps - // (`indices`). - // OR - // 2. There are no further indices (`gep_indices_ids` is empty). In this - // case, the recovery path already leads to the final destination. - if ty == base_pointee_ty || gep_indices_ids.is_empty() { - // Combine the recovered constant indices with the remaining dynamic/constant indices. - let combined_indices = base_indices - .into_iter() - // Convert recovered `u32` indices to constant SPIR-V IDs. - .map(|idx| self.constant_u32(self.span(), idx).def(self)) - // Chain the original subsequent indices (`indices`). - .chain(gep_indices_ids.iter().copied()) - .collect(); + let leaf_ptr_ty = self.type_ptr_to(leaf_pointee_ty); + let original_ptr_id = original_ptr.def(self); + let const_indices_ids = const_indices + .into_iter() + .map(|idx| self.constant_u32(self.span(), idx).def(self)) + .collect(); - trace!( - "emitting access chain via recovery path with wrapper type: {}", - self.debug_type(result_wrapper_type) - ); - // Emit a single AccessChain using the original pointer `ptr_id` and the fully combined index list. - // Note: We don't pass `ptr_base_index` here because its effect is incorporated into `base_indices`. - return self.emit_access_chain( - result_wrapper_type, // The chosen result pointer type - ptr_id, // The original base pointer ID - None, // No separate base index needed - combined_indices, // The combined structured + original indices - is_inbounds, // Preserve original inbounds request - ); - } else { - // Recovery happened, but the recovered type `base_pointee_ty` doesn't match the input `ty`, - // AND there are more `indices` to process. Using the `base_indices` derived from - // `original_pointee_ty` would be incorrect for interpreting the subsequent `indices` - // which were intended to operate relative to `ty`. Fall back to the standard path. 
- trace!( - "Recovery type mismatch ({}) vs ({}) and GEP indices exist, falling back", - self.debug_type(ty), - self.debug_type(base_pointee_ty) - ); - } - } else { - // `recover_access_chain_from_offset` couldn't find a structured path for the constant offset. - trace!("`recover_access_chain_from_offset` returned None, falling back"); + return self.emit_access_chain( + leaf_ptr_ty, + original_ptr_id, + None, + const_indices_ids, + is_inbounds, + ); } } - // --- End Recovery Path --- - - // FIXME(eddyb) the comments below might not make much sense, because this - // used to be in the "fallback path" before being moved to before merging. - // - // Before emitting the AccessChain, explicitly cast the base pointer `ptr` to - // ensure its pointee type matches the input `ty`. This is required because the - // SPIR-V `AccessChain` instruction implicitly uses the size of the base - // pointer's pointee type when applying the *first* index operand (our - // `ptr_base_index`). If `ty` and `original_pointee_ty` mismatched and we - // reached this fallback, this cast ensures SPIR-V validity. - trace!("maybe_inbounds_gep fallback path calling pointercast"); - // Cast ptr to point to `ty`. - // HACK(eddyb) temporary workaround for untyped pointers upstream. - // FIXME(eddyb) replace with untyped memory SPIR-V + `qptr` or similar. - let ptr = self.pointercast(ptr, self.type_ptr_to(ty)); - // Get the ID of the (potentially newly casted) pointer. + // Strategy 3: try merging onto an `OpAccessChain` of matching type, + // and this starts by `pointercast`-ing to that type for two reasons: + // - this is the only type a merged `OpAccessChain` could produce + // - `pointercast` can itself produce a new `OpAccessChain` in the + // right circumstances (e.g. `&array[0]` for `*[T; N]` -> `*T`) + let ptr = self.pointercast(ptr, self.type_ptr_to(stride_elem_ty)); let ptr_id = ptr.def(self); - // HACK(eddyb) updated pointee type of `ptr` post-`pointercast`. 
- let original_pointee_ty = ty;
-
- // --- Attempt GEP Merging Path ---
-
- // Check if the base pointer `ptr` itself was the result of a previous
- // AccessChain instruction. Merging is only attempted if the input type `ty`
- // matches the pointer's actual underlying pointee type `original_pointee_ty`.
- // If they differ, merging could be invalid.
- // HACK(eddyb) always attempted now, because we `pointercast` first, which:
- // - is noop when `ty == original_pointee_ty` pre-`pointercast` (old condition)
- // - may generate (potentially mergeable) new `AccessChain`s in other cases
- let maybe_original_access_chain = if ty == original_pointee_ty {
- // Search the current function's instructions...
- // FIXME(eddyb) this could get ridiculously expensive, at the very least
- // it could use `.rev()`, hoping the base pointer was recently defined?
- let search_result = {
- let emit = self.emit();
- let module = emit.module_ref();
- emit.selected_function().and_then(|func_idx| {
- module.functions.get(func_idx).and_then(|func| {
- // Find the instruction that defined our base pointer `ptr_id`.
- func.all_inst_iter()
- .find(|inst| inst.result_id == Some(ptr_id))
- .and_then(|ptr_def_inst| {
- // Check if that instruction was an `AccessChain` or `InBoundsAccessChain`.
- if matches!(
- ptr_def_inst.class.opcode,
- Op::AccessChain | Op::InBoundsAccessChain
- ) {
- // If yes, extract its base pointer and its indices.
- let base_ptr = ptr_def_inst.operands[0].unwrap_id_ref();
- let indices = ptr_def_inst.operands[1..]
- .iter()
- .map(|op| op.unwrap_id_ref())
- .collect::<Vec<_>>();
- Some((base_ptr, indices))
- } else {
- // The instruction defining ptr was not an `AccessChain`.
- None
- }
- })
- })
+
+ let maybe_original_access_chain = {
+ let emit = self.emit();
+ let module = emit.module_ref();
+ let current_func_blocks = emit
+ .selected_function()
+ .and_then(|func_idx| Some(&module.functions.get(func_idx)?.blocks[..]))
+ .unwrap_or_default();
+
+ // NOTE(eddyb) reverse search (`rfind`) used in the hopes that the
+ // instruction producing the value with ID `ptr_id` is more likely to
+ // have been added more recently to the function, even though there's
+ // still a risk of this causing a whole-function traversal.
+ //
+ // FIXME(eddyb) consider tracking that via e.g. `SpirvValueKind`.
+ current_func_blocks
+ .iter()
+ .flat_map(|b| &b.instructions)
+ .rfind(|inst| inst.result_id == Some(ptr_id))
+ .filter(|inst| {
+ matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
+ })
+ .map(|inst| {
+ let base_ptr = inst.operands[0].unwrap_id_ref();
+ let indices = inst.operands[1..]
+ .iter()
+ .map(|op| op.unwrap_id_ref())
+ .collect::<Vec<_>>();
+ (base_ptr, indices)
+ })
- };
- search_result
- } else {
- // Input type `ty` doesn't match the pointer's actual type, cannot safely merge.
- None
};
+ if let Some((original_ptr, original_indices)) = maybe_original_access_chain {
+ trace!("ptr_offset_strided: strategy 3 picked: merging access chains");
- // If we found that `ptr` was defined by a previous `AccessChain`...
- if let Some((original_ptr, mut original_indices)) = maybe_original_access_chain {
- trace!("has original access chain, attempting to merge GEPs");
+ let mut merged_indices = original_indices;
- // Check if merging is possible. Requires:
- // 1. The original AccessChain had at least one index.
- // 2. The *last* index of the original AccessChain is a constant.
- // 3. The *first* index (`ptr_base_index`) of the *current* GEP is a constant.
- // Merging usually involves adding these two constant indices.
- // - // FIXME(eddyb) the above comment seems inaccurate, there is no reason - // why runtime indices couldn't be added together just like constants - // (and in fact this is needed nowadays for all array indexing). - let can_merge = if let Some(&last_original_idx_id) = original_indices.last() { - // HACK(eddyb) see the above comment, this bypasses the const - // check below, without tripping a clippy warning etc. - let always_merge = true; - always_merge || { - // Check if both the last original index and the current base index are constant scalars. - self.builder - .lookup_const_scalar(last_original_idx_id.with_type(ptr_base_index.ty)) - .is_some() - && self.builder.lookup_const_scalar(ptr_base_index).is_some() - } - } else { - // Original access chain had no indices to merge with. - false - }; - - if can_merge { - let last_original_idx_id = original_indices.last_mut().unwrap(); - // Add the current `ptr_base_index` to the last index of the original chain. - // The result becomes the new last index. - *last_original_idx_id = self - .add( - // Ensure types match for add. - last_original_idx_id.with_type(ptr_base_index.ty), - ptr_base_index, - ) - // Define the result of the addition. - .def(self); - // Append the remaining indices (`indices`) from the current GEP operation. - original_indices.extend(gep_indices_ids); + let last_index_id = merged_indices.last_mut().unwrap(); + *last_index_id = self.add(last_index_id.with_type(index.ty), index).def(self); - trace!( - "emitting merged access chain with pointer to type: {}", - self.debug_type(calculated_pointee_type) - ); - // Emit a *single* AccessChain using the *original* base pointer and the *merged* index list. - // The result type *must* be the `final_spirv_ptr_type` calculated earlier based on the full chain of operations. - return self.emit_access_chain( - final_spirv_ptr_type, // Use the strictly calculated final type. - original_ptr, // Base pointer from the *original* AccessChain. 
- None, // No separate base index; it's merged. - original_indices, // The combined list of indices. - is_inbounds, // Preserve original inbounds request. - ); - } else { - // Cannot merge because one or both relevant indices are not constant, - // or the original chain was empty. - trace!( - "Last index or base offset is not constant, or no last index, cannot merge." - ); - } - } else { - // The base pointer `ptr` was not the result of an AccessChain, or merging - // wasn't attempted due to type mismatch. - trace!("no original access chain to merge with"); + return self.emit_access_chain(ptr.ty, original_ptr, None, merged_indices, is_inbounds); } - // --- End GEP Merging Path --- - - // --- Fallback / Default Path --- - // This path is taken if neither the Recovery nor the Merging path succeeded or applied. - // It performs a more direct translation of the GEP request. + // None of the legalizing strategies above applied, so this operation + // isn't really supported (and will error if actually used from a shader). + // + // FIXME(eddyb) supersede via SPIR-T pointer legalization (e.g. `qptr`). + trace!("ptr_offset_strided: falling back to (illegal) `OpPtrAccessChain`"); - trace!( - "emitting access chain via fallback path with pointer type: {}", - self.debug_type(final_spirv_ptr_type) + let result_ptr = if is_inbounds { + self.emit() + .in_bounds_ptr_access_chain(ptr.ty, None, ptr_id, index.def(self), vec![]) + } else { + self.emit() + .ptr_access_chain(ptr.ty, None, ptr_id, index.def(self), vec![]) + } + .unwrap(); + self.zombie( + result_ptr, + "cannot offset a pointer to an arbitrary element", ); - // Emit the `AccessChain` instruction. - self.emit_access_chain( - final_spirv_ptr_type, // Result *must* be a pointer to the final calculated type. - ptr_id, // Use the (potentially casted) base pointer ID. - Some(ptr_base_index), // Provide the first index separately. - gep_indices_ids, // Provide the rest of the indices. 
- is_inbounds, // Preserve original inbounds request. - ) + result_ptr.with_type(ptr.ty) } #[instrument( From 35ad4b2ea5742625d81954a8e58cff070ae139ca Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Mon, 7 Jul 2025 10:27:18 +0300 Subject: [PATCH 6/6] WIP: builder: castless noop in `ptr_offset_strided`. --- .../src/builder/builder_methods.rs | 47 ++++++++++++++++--- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index a19b462440..43851a242a 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -759,8 +759,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { if const_offset == Some(Size::ZERO) { trace!("ptr_offset_strided: strategy 1 picked: offset 0 => pointer cast"); - // FIXME(eddyb) could this just `return ptr;`? what even breaks? - return self.pointercast(ptr, self.type_ptr_to(stride_elem_ty)); + // FIXME(eddyb) replace docs to remove mentions of pointer casting. + return ptr; } // Strategy 2: try recovering an `OpAccessChain` from a constant offset. 
@@ -3261,6 +3261,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { Bitcast(ID, ID), CompositeExtract(ID, ID, u32), InBoundsAccessChain(ID, ID, u32), + InBoundsAccessChain2(ID, ID, u32, u32), Store(ID, ID), Load(ID, ID), CopyMemory(ID, ID), @@ -3317,6 +3318,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { Inst::Unsupported(inst.class.opcode) } } + (Op::InBoundsAccessChain, Some(r), &[p, i, j]) => { + if let [Some(i), Some(j)] = [i, j].map(const_as_u32) { + Inst::InBoundsAccessChain2(r, p, i, j) + } else { + Inst::Unsupported(inst.class.opcode) + } + } (Op::Store, None, &[p, v]) => Inst::Store(p, v), (Op::Load, Some(r), &[p]) => Inst::Load(r, p), (Op::CopyMemory, None, &[a, b]) => Inst::CopyMemory(a, b), @@ -3508,20 +3516,45 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let rev_copies_to_rt_args_array_src_ptrs: SmallVec<[_; 4]> = (0..rt_args_count).rev().map(|rt_arg_idx| { - let copy_to_rt_args_array_insts = try_rev_take(4).ok_or_else(|| { + let mut copy_to_rt_args_array_insts = try_rev_take(3).ok_or_else(|| { FormatArgsNotRecognized( "[fmt::rt::Argument; N] copy: ran out of instructions".into(), ) })?; + + // HACK(eddyb) account for both the split and combined + // access chain cases that `inbounds_gep` can now cause. + if let Inst::InBoundsAccessChain(dst_field_ptr, dst_base_ptr, 0) = + copy_to_rt_args_array_insts[0] + { + if let Some(mut prev_insts) = try_rev_take(1) { + assert_eq!(prev_insts.len(), 1); + let prev_inst = prev_insts.pop().unwrap(); + + match prev_inst { + Inst::InBoundsAccessChain( + array_elem_ptr, + array_ptr, + idx, + ) if dst_base_ptr == array_elem_ptr => { + copy_to_rt_args_array_insts[0] = + Inst::InBoundsAccessChain2(dst_field_ptr, array_ptr, idx, 0); + } + _ => { + // HACK(eddyb) don't lose the taken `prev_inst`. + copy_to_rt_args_array_insts.insert(0, prev_inst); + } + } + } + } + match copy_to_rt_args_array_insts[..] 
{ [ - Inst::InBoundsAccessChain(array_slot, array_base, array_idx), - Inst::InBoundsAccessChain(dst_field_ptr, dst_base_ptr, 0), + Inst::InBoundsAccessChain2(dst_field_ptr, dst_array_base_ptr, array_idx, 0), Inst::InBoundsAccessChain(src_field_ptr, src_base_ptr, 0), Inst::CopyMemory(copy_dst, copy_src), - ] if array_base == rt_args_array_ptr_id + ] if dst_array_base_ptr == rt_args_array_ptr_id && array_idx as usize == rt_arg_idx - && dst_base_ptr == array_slot && (copy_dst, copy_src) == (dst_field_ptr, src_field_ptr) => { Ok(src_base_ptr)