diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
index 5521aa973c..43851a242a 100644
--- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
+++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
@@ -620,13 +620,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                 && leaf_size_range.contains(&ty_size)
                 && leaf_ty.map_or(true, |leaf_ty| leaf_ty == ty)
             {
-                trace!("returning type: {:?}", self.debug_type(ty));
-                trace!("returning indices with len: {:?}", indices.len());
+                trace!("successful recovery leaf type: {:?}", self.debug_type(ty));
+                trace!("successful recovery indices: {:?}", indices);
                 return Some((indices, ty));
             }
         }
     }
 
+    // NOTE(eddyb) see `ptr_offset_strided`, which this now forwards to.
     #[instrument(level = "trace", skip(self), fields(ty = ?self.debug_type(ty), ptr, combined_indices = ?combined_indices.iter().map(|x| (self.debug_type(x.ty), x.kind)).collect::<Vec<_>>(), is_inbounds))]
     fn maybe_inbounds_gep(
         &mut self,
@@ -655,344 +656,220 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         // https://github.com/gpuweb/gpuweb/issues/33
         let (&ptr_base_index, indices) = combined_indices.split_first().unwrap();
 
-        // Determine if this GEP operation is effectively byte-level addressing.
-        // This check is based on the *provided* input type `ty`. If `ty` is i8 or u8,
-        // it suggests the caller intends to perform byte-offset calculations,
-        // which might allow for more flexible type recovery later.
-        let is_byte_gep = matches!(self.lookup_type(ty), SpirvType::Integer(8, _));
-        trace!("Is byte GEP (based on input type): {}", is_byte_gep);
-
-        // --- Calculate the final pointee type based on the GEP operation ---
-
-        // This loop does the type traversal according to the `indices` (excluding the
-        // base offset index). It starts with the initial element type `ty` and
-        // iteratively applies each index to determine the type of the element being
-        // accessed at each step. The result is the type that the *final* pointer,
-        // generated by the SPIR-V `AccessChain` instruction, *must* point to according
-        // to the SPIR-V specification and the provided `indices`.
-        let mut calculated_pointee_type = ty;
-        for index_val in indices {
-            // Lookup the current aggregate type we are indexing into.
-            calculated_pointee_type = match self.lookup_type(calculated_pointee_type) {
-                // If it's a struct (ADT), the index must be a constant. Use it to get
-                // the field type.
-                SpirvType::Adt { field_types, .. } => {
-                    let const_index = self
-                        .builder
-                        .lookup_const_scalar(*index_val)
-                        .expect("Non-constant struct index for GEP")
-                        as usize;
-                    // Get the type of the specific field.
-                    field_types[const_index]
-                }
-                // If it's an array, vector, or matrix, indexing yields the element type.
-                SpirvType::Array { element, .. }
-                | SpirvType::RuntimeArray { element }
-                | SpirvType::Vector { element, .. }
-                | SpirvType::Matrix { element, .. } => element,
-                // Special case: If we started with a byte GEP (`is_byte_gep` is true) and
-                // we are currently indexing into a byte type, the result is still a byte type.
-                // This prevents errors if `indices` contains non-zero values when `ty` is u8/i8.
-                SpirvType::Integer(8, signedness) if is_byte_gep => {
-                    // Define the resulting byte type (as it might not exist yet).
-                    SpirvType::Integer(8, signedness).def(self.span(), self)
-                }
-                // Any other type cannot be indexed into via GEP.
-                _ => self.fatal(format!(
-                    "GEP not implemented for indexing into type {}",
-                    self.debug_type(calculated_pointee_type)
-                )),
+        let indexed_base_ptr = self.ptr_offset_strided(ptr, ty, ptr_base_index, is_inbounds);
+
+        // HACK(eddyb) this effectively removes any real support for GEPs with
+        // any `indices` (beyond `ptr_base_index`), which should now be the case
+        // across `rustc_codegen_ssa` (see also comment inside `inbounds_gep`).
+        // FIXME(eddyb) are the warning + fallback path even worth keeping?
+        if !indices.is_empty() {
+            // HACK(eddyb) Cargo silences warnings in dependencies.
+            let force_warn = |span, msg| -> rustc_errors::Diag<'_, ()> {
+                rustc_errors::Diag::new(
+                    self.tcx.dcx(),
+                    rustc_errors::Level::ForceWarning(None),
+                    msg,
+                )
+                .with_span(span)
             };
+            force_warn(
+                self.span(),
+                format!(
+                    "[RUST-GPU BUG] `inbounds_gep` or `gep` called with \
+                     {} combined indices (expected only 1)",
+                    combined_indices.len(),
+                ),
+            )
+            .emit();
+
+            let indexed_base_ptr_id = indexed_base_ptr.def(self);
+            assert_ty_eq!(self, indexed_base_ptr.ty, self.type_ptr_to(ty));
+
+            let mut final_pointee = ty;
+            for &index in indices {
+                final_pointee = match self.lookup_type(final_pointee) {
+                    SpirvType::Adt { field_types, .. } => {
+                        field_types[self
+                            .builder
+                            .lookup_const_scalar(index)
+                            .expect("non-const struct index for GEP")
+                            as usize]
+                    }
+                    SpirvType::Array { element, .. }
+                    | SpirvType::RuntimeArray { element }
+                    | SpirvType::Vector { element, .. }
+                    | SpirvType::Matrix { element, .. } => element,
+
+                    _ => self.fatal(format!(
+                        "GEP not implemented for indexing into type {}",
+                        self.debug_type(final_pointee)
+                    )),
+                };
+            }
+            let final_spirv_ptr_type = self.type_ptr_to(final_pointee);
+
+            let indices_ids: Vec<_> = indices.iter().map(|index| index.def(self)).collect();
+
+            return self.emit_access_chain(
+                final_spirv_ptr_type,
+                indexed_base_ptr_id,
+                None,
+                indices_ids,
+                is_inbounds,
+            );
         }
 
-        // Construct the SPIR-V pointer type that points to the final calculated pointee
-        // type. This is the *required* result type for the SPIR-V `AccessChain`
-        // instruction.
-        let final_spirv_ptr_type = self.type_ptr_to(calculated_pointee_type);
-        trace!(
-            "Calculated final SPIR-V pointee type: {}",
-            self.debug_type(calculated_pointee_type)
-        );
-        trace!(
-            "Calculated final SPIR-V ptr type: {}",
-            self.debug_type(final_spirv_ptr_type)
-        );
-        // Ensure all the `indices` (excluding the base offset index) are defined in the
-        // SPIR-V module and get their corresponding SPIR-V IDs. These IDs will be used
-        // as operands in the AccessChain instruction.
-        let gep_indices_ids: Vec<_> = indices.iter().map(|index| index.def(self)).collect();
+        indexed_base_ptr
+    }
 
-        // --- Prepare the base pointer ---
+    /// Array-indexing-like pointer arithmetic, i.e. `(ptr: *T).offset(index)`
+    /// (or `wrapping_offset` instead of `offset`, for `is_inbounds = false`),
+    /// where `T` is given by `stride_elem_ty` (named so for extra clarity).
+    ///
+    /// This can produce legal SPIR-V by using 3 strategies:
+    /// 1. `pointercast` for a constant offset of `0`
+    ///    - itself can succeed via `recover_access_chain_from_offset`
+    ///    - even if a specific cast is unsupported, legal SPIR-V can still be
+    ///      obtained, if all downstream uses rely on e.g. `strip_ptrcasts`
+    ///    - also used before the merge strategy (3.), to improve its chances
+    /// 2. `recover_access_chain_from_offset` for constant offsets
+    ///    (e.g. from `ptradd`/`inbounds_ptradd` used to access `struct` fields)
+    /// 3. merging onto an array `OpAccessChain` with the same `stride_elem_ty`
+    ///    (possibly `&array[0]` from `pointercast` doing `*[T; N]` -> `*T`)
+    #[instrument(level = "trace", skip(self), fields(ptr, stride_elem_ty = ?self.debug_type(stride_elem_ty), index, is_inbounds))]
+    fn ptr_offset_strided(
+        &mut self,
+        ptr: SpirvValue,
+        stride_elem_ty: Word,
+        index: SpirvValue,
+        is_inbounds: bool,
+    ) -> SpirvValue {
+        // Precompute a constant `index * stride` (i.e. effective pointer offset)
+        // if possible, as both strategies 1 and 2 rely on knowing this value.
+        let const_offset = self.builder.lookup_const_scalar(index).and_then(|idx| {
+            let idx_u64 = u64::try_from(idx).ok()?;
+            let stride = self.lookup_type(stride_elem_ty).sizeof(self)?;
+            Some(idx_u64 * stride)
+        });
 
-        // Remove any potentially redundant pointer casts applied to the input `ptr`.
-        // GEP operations should ideally work on the "underlying" pointer.
-        let ptr = ptr.strip_ptrcasts();
-        // Get the SPIR-V ID for the (potentially stripped) base pointer.
-        let ptr_id = ptr.def(self);
-        // Determine the actual pointee type of the base pointer `ptr` *after* stripping casts.
-        // This might differ from the input `ty` if `ty` was less specific (e.g., u8).
-        let original_pointee_ty = match self.lookup_type(ptr.ty) {
-            SpirvType::Pointer { pointee } => pointee,
-            other => self.fatal(format!("gep called on non-pointer type: {other:?}")),
-        };
+        // Strategy 1: an offset of `0` is always a noop, and `pointercast`
+        // gets to use `SpirvValueKind::LogicalPtrCast`, which can later
+        // be "undone" (by `strip_ptrcasts`), allowing more flexibility
+        // downstream (instead of overeagerly "shrinking" the pointee).
+        if const_offset == Some(Size::ZERO) {
+            trace!("ptr_offset_strided: strategy 1 picked: offset 0 => pointer cast");
 
-        // --- Recovery Path ---
+            // FIXME(eddyb) replace docs to remove mentions of pointer casting.
+            return ptr;
+        }
 
-        // Try to calculate the byte offset implied by the *first* index
-        // (`ptr_base_index`) if it's a compile-time constant. This uses the size of the
-        // *input type* `ty`.
-        let const_ptr_offset_bytes = self
-            .builder
-            .lookup_const_scalar(ptr_base_index) // Check if ptr_base_index is constant scalar
-            .and_then(|idx| {
-                let idx_u64 = u64::try_from(idx).ok()?;
-                // Get the size of the input type `ty`
-                self.lookup_type(ty)
-                    .sizeof(self)
-                    // Calculate offset in bytes
-                    .map(|size| idx_u64.saturating_mul(size.bytes()))
-            });
-
-        // If we successfully calculated a constant byte offset for the first index...
-        if let Some(const_ptr_offset_bytes) = const_ptr_offset_bytes {
-            // Try to reconstruct a more "structured" access chain based on the *original*
-            // pointee type of the pointer (`original_pointee_ty`) and the calculated byte offset.
-            // This is useful if the input `ty` was generic (like u8) but the pointer actually
-            // points to a structured type (like a struct). `recover_access_chain_from_offset`
-            // attempts to find a sequence of constant indices (`base_indices`) into
-            // `original_pointee_ty` that matches the `const_ptr_offset_bytes`.
-            if let Some((base_indices, base_pointee_ty)) = self.recover_access_chain_from_offset(
-                // Start from the pointer's actual underlying type
+        // Strategy 2: try recovering an `OpAccessChain` from a constant offset.
+        if let Some(const_offset) = const_offset {
+            // Remove any (redundant) pointer casts applied to the input `ptr`,
+            // to obtain the "most original" pointer (which ideally will be e.g.
+            // a whole `OpVariable`, or the result of a previous `OpAccessChain`).
+            let original_ptr = ptr.strip_ptrcasts();
+            let original_pointee_ty = match self.lookup_type(original_ptr.ty) {
+                SpirvType::Pointer { pointee } => pointee,
+                other => self.fatal(format!("pointer arithmetic on non-pointer type {other:?}")),
+            };
+
+            if let Some((const_indices, leaf_pointee_ty)) = self.recover_access_chain_from_offset(
                 original_pointee_ty,
-                // The target byte offset
-                Size::from_bytes(const_ptr_offset_bytes),
-                // Allowed range (not strictly needed here?)
+                const_offset,
                 Some(Size::ZERO)..=None,
-                // Don't require alignment
                 None,
             ) {
-                // Recovery successful! Found a structured path (`base_indices`) to the target offset.
                 trace!(
-                    "`recover_access_chain_from_offset` returned Some with base_pointee_ty: {}",
-                    self.debug_type(base_pointee_ty)
+                    "ptr_offset_strided: strategy 2 picked: offset {const_offset:?} \
+                     => access chain w/ {const_indices:?}"
                 );
-                // Determine the result type for the `AccessChain` instruction we might
-                // emit. By default, use the `final_spirv_ptr_type` strictly calculated
-                // earlier from `ty` and `indices`.
-                //
-                // If this is a byte GEP *and* the recovered type happens to be a byte
-                // type, we can use the pointer type derived from the *recovered* type
-                // (`base_pointee_ty`). This helps preserve type information when
-                // recovery works for byte addressing.
-                let result_wrapper_type = if !is_byte_gep
-                    || matches!(self.lookup_type(base_pointee_ty), SpirvType::Integer(8, _))
-                {
-                    trace!(
-                        "Using strictly calculated type for wrapper: {}",
-                        // Use type based on input `ty` + `indices`
-                        self.debug_type(calculated_pointee_type)
-                    );
-                    final_spirv_ptr_type
-                } else {
-                    trace!(
-                        "Byte GEP allowing recovered type for wrapper: {}",
-                        // Use type based on recovery result
-                        self.debug_type(base_pointee_ty)
-                    );
-                    self.type_ptr_to(base_pointee_ty)
-                };
-
-                // Check if we can directly use the recovered path combined with the
-                // remaining indices. This is possible if:
-                // 1. The input type `ty` matches the type found by recovery
-                //    (`base_pointee_ty`). This means the recovery didn't fundamentally
-                //    change the type interpretation needed for the *next* steps
-                //    (`indices`).
-                // OR
-                // 2. There are no further indices (`gep_indices_ids` is empty). In this
-                //    case, the recovery path already leads to the final destination.
-                if ty == base_pointee_ty || gep_indices_ids.is_empty() {
-                    // Combine the recovered constant indices with the remaining dynamic/constant indices.
-                    let combined_indices = base_indices
-                        .into_iter()
-                        // Convert recovered `u32` indices to constant SPIR-V IDs.
-                        .map(|idx| self.constant_u32(self.span(), idx).def(self))
-                        // Chain the original subsequent indices (`indices`).
-                        .chain(gep_indices_ids.iter().copied())
-                        .collect();
+                let leaf_ptr_ty = self.type_ptr_to(leaf_pointee_ty);
+                let original_ptr_id = original_ptr.def(self);
+                let const_indices_ids = const_indices
+                    .into_iter()
+                    .map(|idx| self.constant_u32(self.span(), idx).def(self))
+                    .collect();
 
-                    trace!(
-                        "emitting access chain via recovery path with wrapper type: {}",
-                        self.debug_type(result_wrapper_type)
-                    );
-                    // Emit a single AccessChain using the original pointer `ptr_id` and the fully combined index list.
-                    // Note: We don't pass `ptr_base_index` here because its effect is incorporated into `base_indices`.
-                    return self.emit_access_chain(
-                        result_wrapper_type, // The chosen result pointer type
-                        ptr_id, // The original base pointer ID
-                        None, // No separate base index needed
-                        combined_indices, // The combined structured + original indices
-                        is_inbounds, // Preserve original inbounds request
-                    );
-                } else {
-                    // Recovery happened, but the recovered type `base_pointee_ty` doesn't match the input `ty`,
-                    // AND there are more `indices` to process. Using the `base_indices` derived from
-                    // `original_pointee_ty` would be incorrect for interpreting the subsequent `indices`
-                    // which were intended to operate relative to `ty`. Fall back to the standard path.
-                    trace!(
-                        "Recovery type mismatch ({}) vs ({}) and GEP indices exist, falling back",
-                        self.debug_type(ty),
-                        self.debug_type(base_pointee_ty)
-                    );
-                }
-            } else {
-                // `recover_access_chain_from_offset` couldn't find a structured path for the constant offset.
-                trace!("`recover_access_chain_from_offset` returned None, falling back");
+                return self.emit_access_chain(
+                    leaf_ptr_ty,
+                    original_ptr_id,
+                    None,
+                    const_indices_ids,
+                    is_inbounds,
+                );
             }
         }
 
-        // --- End Recovery Path ---
-
-        // --- Attempt GEP Merging Path ---
+        // Strategy 3: try merging onto an `OpAccessChain` of matching type,
+        // and this starts by `pointercast`-ing to that type for two reasons:
+        // - this is the only type a merged `OpAccessChain` could produce
+        // - `pointercast` can itself produce a new `OpAccessChain` in the
+        //   right circumstances (e.g. `&array[0]` for `*[T; N]` -> `*T`)
+        let ptr = self.pointercast(ptr, self.type_ptr_to(stride_elem_ty));
+        let ptr_id = ptr.def(self);
 
-        // Check if the base pointer `ptr` itself was the result of a previous
-        // AccessChain instruction. Merging is only attempted if the input type `ty`
-        // matches the pointer's actual underlying pointee type `original_pointee_ty`.
-        // If they differ, merging could be invalid.
-        let maybe_original_access_chain = if ty == original_pointee_ty {
-            // Search the current function's instructions...
-            // FIXME(eddyb) this could get ridiculously expensive, at the very least
-            // it could use `.rev()`, hoping the base pointer was recently defined?
-            let search_result = {
-                let emit = self.emit();
-                let module = emit.module_ref();
-                emit.selected_function().and_then(|func_idx| {
-                    module.functions.get(func_idx).and_then(|func| {
-                        // Find the instruction that defined our base pointer `ptr_id`.
-                        func.all_inst_iter()
-                            .find(|inst| inst.result_id == Some(ptr_id))
-                            .and_then(|ptr_def_inst| {
-                                // Check if that instruction was an `AccessChain` or `InBoundsAccessChain`.
-                                if matches!(
-                                    ptr_def_inst.class.opcode,
-                                    Op::AccessChain | Op::InBoundsAccessChain
-                                ) {
-                                    // If yes, extract its base pointer and its indices.
-                                    let base_ptr = ptr_def_inst.operands[0].unwrap_id_ref();
-                                    let indices = ptr_def_inst.operands[1..]
-                                        .iter()
-                                        .map(|op| op.unwrap_id_ref())
-                                        .collect::<Vec<_>>();
-                                    Some((base_ptr, indices))
-                                } else {
-                                    // The instruction defining ptr was not an `AccessChain`.
-                                    None
-                                }
-                            })
-                    })
-                })
-            };
-            search_result
-        } else {
-            // Input type `ty` doesn't match the pointer's actual type, cannot safely merge.
-            None
+        let maybe_original_access_chain = {
+            let emit = self.emit();
+            let module = emit.module_ref();
+            let current_func_blocks = emit
+                .selected_function()
+                .and_then(|func_idx| Some(&module.functions.get(func_idx)?.blocks[..]))
+                .unwrap_or_default();
+
+            // NOTE(eddyb) reverse search (`rfind`) used in the hopes that the
+            // instruction producing the value with ID `ptr_id` is more likely to
+            // have been added more recently to the function, even though there's
+            // still a risk of this causing a whole-function traversal.
+            //
+            // FIXME(eddyb) consider tracking that via e.g. `SpirvValueKind`.
+            current_func_blocks
+                .iter()
+                .flat_map(|b| &b.instructions)
+                .rfind(|inst| inst.result_id == Some(ptr_id))
+                .filter(|inst| {
+                    matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
+                })
+                .map(|inst| {
+                    let base_ptr = inst.operands[0].unwrap_id_ref();
+                    let indices = inst.operands[1..]
+                        .iter()
+                        .map(|op| op.unwrap_id_ref())
+                        .collect::<Vec<_>>();
+                    (base_ptr, indices)
+                })
         };
+        if let Some((original_ptr, original_indices)) = maybe_original_access_chain {
+            trace!("ptr_offset_strided: strategy 3 picked: merging access chains");
 
-        // If we found that `ptr` was defined by a previous `AccessChain`...
-        if let Some((original_ptr, mut original_indices)) = maybe_original_access_chain {
-            trace!("has original access chain, attempting to merge GEPs");
-
-            // Check if merging is possible. Requires:
-            // 1. The original AccessChain had at least one index.
-            // 2. The *last* index of the original AccessChain is a constant.
-            // 3. The *first* index (`ptr_base_index`) of the *current* GEP is a constant.
-            // Merging usually involves adding these two constant indices.
-            let can_merge = if let Some(&last_original_idx_id) = original_indices.last() {
-                // Check if both the last original index and the current base index are constant scalars.
-                self.builder
-                    .lookup_const_scalar(last_original_idx_id.with_type(ptr_base_index.ty))
-                    .is_some()
-                    && self.builder.lookup_const_scalar(ptr_base_index).is_some()
-            } else {
-                // Original access chain had no indices to merge with.
-                false
-            };
+            let mut merged_indices = original_indices;
 
-            if can_merge {
-                let last_original_idx_id = original_indices.last_mut().unwrap();
-                // Add the current `ptr_base_index` to the last index of the original chain.
-                // The result becomes the new last index.
-                *last_original_idx_id = self
-                    .add(
-                        // Ensure types match for add.
-                        last_original_idx_id.with_type(ptr_base_index.ty),
-                        ptr_base_index,
-                    )
-                    // Define the result of the addition.
-                    .def(self);
-                // Append the remaining indices (`indices`) from the current GEP operation.
-                original_indices.extend(gep_indices_ids);
+            let last_index_id = merged_indices.last_mut().unwrap();
+            *last_index_id = self.add(last_index_id.with_type(index.ty), index).def(self);
 
-                trace!(
-                    "emitting merged access chain with pointer to type: {}",
-                    self.debug_type(calculated_pointee_type)
-                );
-                // Emit a *single* AccessChain using the *original* base pointer and the *merged* index list.
-                // The result type *must* be the `final_spirv_ptr_type` calculated earlier based on the full chain of operations.
-                return self.emit_access_chain(
-                    final_spirv_ptr_type, // Use the strictly calculated final type.
-                    original_ptr, // Base pointer from the *original* AccessChain.
-                    None, // No separate base index; it's merged.
-                    original_indices, // The combined list of indices.
-                    is_inbounds, // Preserve original inbounds request.
-                );
-            } else {
-                // Cannot merge because one or both relevant indices are not constant,
-                // or the original chain was empty.
-                trace!(
-                    "Last index or base offset is not constant, or no last index, cannot merge."
-                );
-            }
-        } else {
-            // The base pointer `ptr` was not the result of an AccessChain, or merging
-            // wasn't attempted due to type mismatch.
-            trace!("no original access chain to merge with");
-        }
-        // --- End GEP Merging Path ---
-
-        // --- Fallback / Default Path ---
-
-        // This path is taken if neither the Recovery nor the Merging path succeeded or applied.
-        // It performs a more direct translation of the GEP request.
-
-        // HACK(eddyb): Workaround for potential upstream issues where pointers might lack precise type info.
-        // FIXME(eddyb): Ideally, this should use untyped memory features if available/necessary.
-
-        // Before emitting the AccessChain, explicitly cast the base pointer `ptr` to
-        // ensure its pointee type matches the input `ty`. This is required because the
-        // SPIR-V `AccessChain` instruction implicitly uses the size of the base
-        // pointer's pointee type when applying the *first* index operand (our
-        // `ptr_base_index`). If `ty` and `original_pointee_ty` mismatched and we
-        // reached this fallback, this cast ensures SPIR-V validity.
-        trace!("maybe_inbounds_gep fallback path calling pointercast");
-        // Cast ptr to point to `ty`.
-        let ptr = self.pointercast(ptr, self.type_ptr_to(ty));
-        // Get the ID of the (potentially newly casted) pointer.
-        let ptr_id = ptr.def(self);
+            return self.emit_access_chain(ptr.ty, original_ptr, None, merged_indices, is_inbounds);
+        }
 
-        trace!(
-            "emitting access chain via fallback path with pointer type: {}",
-            self.debug_type(final_spirv_ptr_type)
+        // None of the legalizing strategies above applied, so this operation
+        // isn't really supported (and will error if actually used from a shader).
+        //
+        // FIXME(eddyb) supersede via SPIR-T pointer legalization (e.g. `qptr`).
+        trace!("ptr_offset_strided: falling back to (illegal) `OpPtrAccessChain`");
+
+        let result_ptr = if is_inbounds {
+            self.emit()
+                .in_bounds_ptr_access_chain(ptr.ty, None, ptr_id, index.def(self), vec![])
+        } else {
+            self.emit()
+                .ptr_access_chain(ptr.ty, None, ptr_id, index.def(self), vec![])
+        }
+        .unwrap();
+        self.zombie(
+            result_ptr,
+            "cannot offset a pointer to an arbitrary element",
         );
-        // Emit the `AccessChain` instruction.
-        self.emit_access_chain(
-            final_spirv_ptr_type, // Result *must* be a pointer to the final calculated type.
-            ptr_id, // Use the (potentially casted) base pointer ID.
-            Some(ptr_base_index), // Provide the first index separately.
-            gep_indices_ids, // Provide the rest of the indices.
-            is_inbounds, // Preserve original inbounds request.
-        )
+        result_ptr.with_type(ptr.ty)
     }
 
     #[instrument(
@@ -1956,6 +1833,21 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
         ptr: Self::Value,
         indices: &[Self::Value],
     ) -> Self::Value {
+        // HACK(eddyb) effectively a backport of this `gep [0, i]` -> `gep [i]`
+        // PR: https://github.com/rust-lang/rust/pull/134117 to even earlier
+        // nightlies - and that PR happens to remove the last GEP that can be
+        // emitted with any "structured" (struct/array) indices, beyond the
+        // "first index" (which acts as `<*T>::offset` aka "pointer arithmetic").
+        if let &[ptr_base_index, structured_index] = indices {
+            if self.builder.lookup_const_scalar(ptr_base_index) == Some(0) {
+                if let SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element, .. } =
+                    self.lookup_type(ty)
+                {
+                    return self.maybe_inbounds_gep(element, ptr, &[structured_index], true);
+                }
+            }
+        }
+
         self.maybe_inbounds_gep(ty, ptr, indices, true)
     }
 
@@ -3369,6 +3261,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
             Bitcast(ID, ID),
             CompositeExtract(ID, ID, u32),
             InBoundsAccessChain(ID, ID, u32),
+            InBoundsAccessChain2(ID, ID, u32, u32),
             Store(ID, ID),
             Load(ID, ID),
             CopyMemory(ID, ID),
@@ -3425,6 +3318,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
                         Inst::Unsupported(inst.class.opcode)
                     }
                 }
+                (Op::InBoundsAccessChain, Some(r), &[p, i, j]) => {
+                    if let [Some(i), Some(j)] = [i, j].map(const_as_u32) {
+                        Inst::InBoundsAccessChain2(r, p, i, j)
+                    } else {
+                        Inst::Unsupported(inst.class.opcode)
+                    }
+                }
                 (Op::Store, None, &[p, v]) => Inst::Store(p, v),
                 (Op::Load, Some(r), &[p]) => Inst::Load(r, p),
                 (Op::CopyMemory, None, &[a, b]) => Inst::CopyMemory(a, b),
@@ -3616,20 +3516,45 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
 
         let rev_copies_to_rt_args_array_src_ptrs: SmallVec<[_; 4]> =
             (0..rt_args_count).rev().map(|rt_arg_idx| {
-                let copy_to_rt_args_array_insts = try_rev_take(4).ok_or_else(|| {
+                let mut copy_to_rt_args_array_insts = try_rev_take(3).ok_or_else(|| {
                     FormatArgsNotRecognized(
                         "[fmt::rt::Argument; N] copy: ran out of instructions".into(),
                     )
                 })?;
+
+                // HACK(eddyb) account for both the split and combined
+                // access chain cases that `inbounds_gep` can now cause.
+                if let Inst::InBoundsAccessChain(dst_field_ptr, dst_base_ptr, 0) =
+                    copy_to_rt_args_array_insts[0]
+                {
+                    if let Some(mut prev_insts) = try_rev_take(1) {
+                        assert_eq!(prev_insts.len(), 1);
+                        let prev_inst = prev_insts.pop().unwrap();
+
+                        match prev_inst {
+                            Inst::InBoundsAccessChain(
+                                array_elem_ptr,
+                                array_ptr,
+                                idx,
+                            ) if dst_base_ptr == array_elem_ptr => {
+                                copy_to_rt_args_array_insts[0] =
+                                    Inst::InBoundsAccessChain2(dst_field_ptr, array_ptr, idx, 0);
+                            }
+                            _ => {
+                                // HACK(eddyb) don't lose the taken `prev_inst`.
+                                copy_to_rt_args_array_insts.insert(0, prev_inst);
+                            }
+                        }
+                    }
+                }
+
                 match copy_to_rt_args_array_insts[..] {
                     [
-                        Inst::InBoundsAccessChain(array_slot, array_base, array_idx),
-                        Inst::InBoundsAccessChain(dst_field_ptr, dst_base_ptr, 0),
+                        Inst::InBoundsAccessChain2(dst_field_ptr, dst_array_base_ptr, array_idx, 0),
                         Inst::InBoundsAccessChain(src_field_ptr, src_base_ptr, 0),
                         Inst::CopyMemory(copy_dst, copy_src),
-                    ] if array_base == rt_args_array_ptr_id
+                    ] if dst_array_base_ptr == rt_args_array_ptr_id
                         && array_idx as usize == rt_arg_idx
-                        && dst_base_ptr == array_slot
                         && (copy_dst, copy_src) == (dst_field_ptr, src_field_ptr) =>
                     {
                         Ok(src_base_ptr)
diff --git a/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr b/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr
index d2066bab76..5a0a23a6fb 100644
--- a/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr
+++ b/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr
@@ -6,27 +6,45 @@ OpEntryPoint Fragment %2 "main"
 OpExecutionMode %2 OriginUpperLeft
 %3 = OpString "\n[Rust panicked at $DIR/panic_builtin_bounds_check.rs:25:5]\n index out of bounds: the len is %u but the index is %u\n in main()\n"
 %4 = OpString "$DIR/panic_builtin_bounds_check.rs"
-%5 = OpTypeVoid
-%6 = OpTypeFunction %5
-%7 = OpTypeBool
+OpDecorate %5 ArrayStride 4
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
 %8 = OpTypeInt 32 0
-%9 = OpConstant %8 5
-%10 = OpConstant %8 4
-%11 = OpUndef %8
-%2 = OpFunction %5 None %6
-%12 = OpLabel
+%9 = OpConstant %8 4
+%5 = OpTypeArray %8 %9
+%10 = OpTypePointer Function %5
+%11 = OpConstant %8 0
+%12 = OpConstant %8 1
+%13 = OpConstant %8 2
+%14 = OpConstant %8 3
+%15 = OpTypeBool
+%16 = OpConstant %8 5
+%17 = OpUndef %8
+%18 = OpTypePointer Function %8
+%2 = OpFunction %6 None %7
+%19 = OpLabel
+OpLine %4 30 4
+%20 = OpVariable %10 Function
+OpLine %4 30 23
+%21 = OpCompositeConstruct %5 %11 %12 %13 %14
 OpLine %4 25 4
-%13 = OpULessThan %7 %9 %10
+OpStore %20 %21
+%22 = OpULessThan %15 %16 %9
 OpNoLine
-OpSelectionMerge %14 None
-OpBranchConditional %13 %15 %16
-%15 = OpLabel
-OpBranch %14
-%16 = OpLabel
+OpSelectionMerge %23 None
+OpBranchConditional %22 %24 %25
+%24 = OpLabel
+OpBranch %23
+%25 = OpLabel
 OpLine %4 25 4
-%17 = OpExtInst %5 %1 1 %3 %11 %9
+%26 = OpExtInst %6 %1 1 %3 %17 %16
 OpNoLine
 OpReturn
-%14 = OpLabel
+%23 = OpLabel
+OpLine %4 25 4
+%27 = OpIAdd %8 %11 %16
+%28 = OpInBoundsAccessChain %18 %20 %27
+%29 = OpLoad %8 %28
+OpNoLine
 OpReturn
 OpFunctionEnd
diff --git a/tests/compiletests/ui/lang/issue-46.rs b/tests/compiletests/ui/lang/issue-46.rs
index 99a0835a95..fc64dd01d7 100644
--- a/tests/compiletests/ui/lang/issue-46.rs
+++ b/tests/compiletests/ui/lang/issue-46.rs
@@ -2,7 +2,7 @@
 
 use spirv_std::spirv;
 
-#[derive(Default)]
+#[derive(Copy, Clone, Default)]
 struct Foo {
     bar: bool,
     baz: [[u32; 2]; 1],
@@ -13,3 +13,20 @@ pub fn main() {
     let x = [[1; 2]; 1];
     let y = [Foo::default(); 1];
 }
+
+// HACK(eddyb) future-proofing against `[expr; 1]` -> `[expr]`
+// MIR optimization (https://github.com/rust-lang/rust/pull/135322).
+fn force_repeat_one<const ONE: usize>() -> ([[u32; 2]; ONE], [Foo; ONE]) {
+    ([[1; 2]; ONE], [Foo::default(); ONE])
+}
+
+#[spirv(fragment)]
+pub fn main_future_proof(
+    #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut [[u32; 2]; 2],
+) {
+    let (x, y) = force_repeat_one::<1>();
+
+    // NOTE(eddyb) further guard against optimizations by using `x` and `y`.
+    out[0] = x[0];
+    out[1] = y[0].baz[0];
+}
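A rough mental model of the strategy order in the new `ptr_offset_strided`, for reviewers less familiar with the codegen internals. This is a minimal, self-contained sketch: `Strategy`, `pick_strategy`, and the toy inputs are hypothetical stand-ins, not the real `rustc_codegen_spirv` API, and it deliberately ignores that strategy 2 can fail its recovery and fall through to strategy 3.

```rust
// Toy model of the strategy order in `ptr_offset_strided` (illustrative only).
#[derive(Debug, PartialEq)]
enum Strategy {
    NoopCast,           // strategy 1: constant offset of 0 => just a pointer cast
    RecoverAccessChain, // strategy 2: constant offset => struct/array indices
    MergeOrFallback,    // strategy 3: merge onto an `OpAccessChain` (or zombie)
}

fn pick_strategy(const_index: Option<u64>, stride: u64) -> Strategy {
    // Mirrors `const_offset = lookup_const_scalar(index) * sizeof(stride_elem_ty)`.
    match const_index.map(|i| i * stride) {
        Some(0) => Strategy::NoopCast,
        Some(_) => Strategy::RecoverAccessChain,
        None => Strategy::MergeOrFallback, // dynamic index: offset unknown
    }
}

fn main() {
    // `&s.b`-style field access: constant index 1 with stride 4.
    assert_eq!(pick_strategy(Some(1), 4), Strategy::RecoverAccessChain);
    // Offset 0 is always a no-op, whatever the stride.
    assert_eq!(pick_strategy(Some(0), 4), Strategy::NoopCast);
    // `&slice[i]` with a runtime `i`: no compile-time offset.
    assert_eq!(pick_strategy(None, 4), Strategy::MergeOrFallback);
}
```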
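What `recover_access_chain_from_offset` is asked to do in strategy 2 can be shown with a toy layout walker: given a pointee type and a constant byte offset, produce the constant indices of a leaf at exactly that offset. Everything below (the `Ty` enum, `recover`, the C-like field offsets) is an assumed simplification for illustration; the real function also takes a leaf size range and alignment, omitted here.

```rust
// Minimal sketch of the idea behind `recover_access_chain_from_offset`.
enum Ty<'a> {
    U32,
    Array { elem: &'a Ty<'a>, len: u64 },
    Struct { fields: &'a [(u64, Ty<'a>)] }, // (byte offset, field type)
}

fn size_of(ty: &Ty) -> u64 {
    match ty {
        Ty::U32 => 4,
        Ty::Array { elem, len } => size_of(elem) * len,
        Ty::Struct { fields } => fields.iter().map(|(o, t)| o + size_of(t)).max().unwrap_or(0),
    }
}

// Walk `ty`, accumulating indices, until `offset` lands exactly on a leaf.
fn recover(ty: &Ty, offset: u64, indices: &mut Vec<u32>) -> bool {
    match ty {
        Ty::U32 => offset == 0,
        Ty::Array { elem, .. } => {
            let stride = size_of(elem);
            indices.push((offset / stride) as u32);
            recover(elem, offset % stride, indices)
        }
        // Pick the last field starting at or before `offset`, then recurse.
        Ty::Struct { fields } => fields
            .iter()
            .enumerate()
            .rev()
            .find(|(_, (field_off, _))| *field_off <= offset)
            .is_some_and(|(i, (field_off, field_ty))| {
                indices.push(i as u32);
                recover(field_ty, offset - field_off, indices)
            }),
    }
}

fn main() {
    // struct S { a: u32, b: [u32; 4] } -- `&s.b[2]` is byte offset 4 + 2*4 = 12.
    let elem = Ty::U32;
    let b = Ty::Array { elem: &elem, len: 4 };
    let fields = [(0, Ty::U32), (4, b)];
    let s = Ty::Struct { fields: &fields };

    let mut idx = Vec::new();
    assert!(recover(&s, 12, &mut idx));
    assert_eq!(idx, [1, 2]); // field `b`, element 2 -> `OpAccessChain` indices
}
```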
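The effect of strategy 3, sketched at the SPIR-V level: rather than emitting an `OpPtrAccessChain` (illegal under logical addressing) off an existing `OpAccessChain`, the new index is added onto the old chain's last index. The IDs below are made up; the index arithmetic is the only point.

```rust
// What strategy 3 rewrites (IDs illustrative only):
//
//   %elem0 = OpAccessChain %_ptr_u32 %array %zero   ; e.g. from `pointercast`
//   %elemi = OpPtrAccessChain ... %elem0 %i         ; what must be avoided
//
// becomes a single, legal access chain on the *original* base pointer:
//
//   %sum   = OpIAdd %u32 %zero %i
//   %elemi = OpAccessChain %_ptr_u32 %array %sum
//
// i.e. offsetting `&array[j]` by `i` elements is just `&array[j + i]`:
fn merge_last_index(mut chain: Vec<u64>, index: u64) -> Vec<u64> {
    *chain.last_mut().expect("merge needs a non-empty index list") += index;
    chain
}

fn main() {
    assert_eq!(merge_last_index(vec![0], 5), [5]); // `&arr[0]` + 5 -> `&arr[5]`
}
```

This is exactly the shape visible in the updated `panic_builtin_bounds_check.stderr`, where `%27 = OpIAdd %8 %11 %16` (0 + 5) feeds a single `OpInBoundsAccessChain`.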
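Why the `gep [0, i]` -> `gep [i]` rewrite in `inbounds_gep` is sound: under GEP semantics the first index steps by whole `[T; N]` strides and the second selects an element inside, so a leading `0` contributes nothing to the byte offset. A minimal check of that equivalence (element size and array length are arbitrary assumptions):

```rust
const N: u64 = 4; // array length in `[T; N]`
const SIZE_T: u64 = 4; // size_of::<T>(), e.g. T = u32

// `gep(ty = [T; N], indices = [0, i])`: first index advances by whole
// arrays (N * SIZE_T bytes each), second selects an element inside one.
fn offset_via_array(i: u64) -> u64 {
    0 * (N * SIZE_T) + i * SIZE_T
}

// `gep(ty = T, indices = [i])`: plain `<*T>::offset`-style arithmetic.
fn offset_via_elem(i: u64) -> u64 {
    i * SIZE_T
}

fn main() {
    for i in 0..N {
        assert_eq!(offset_via_array(i), offset_via_elem(i));
    }
}
```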
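Finally, the format-args decoder change: `inbounds_gep` may now produce either a combined `gep [i, 0]` or a split `gep [i]; gep [0]` pair for the `&args[i]` field copy, so the matcher normalizes the split form into the new `InBoundsAccessChain2` shape before pattern-matching. A simplified mirror of that normalization (this `Inst` enum is a hypothetical stand-in for the real one):

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Inst {
    AccessChain1 { result: u32, base: u32, idx: u32 },
    AccessChain2 { result: u32, base: u32, idx: [u32; 2] },
}

// Fuse `%e = gep %arr [i]; %f = gep %e [0]` into `%f = gep %arr [i, 0]`,
// so one pattern can match both the split and combined forms.
fn fuse(prev: Inst, last: Inst) -> Option<Inst> {
    match (prev, last) {
        (
            Inst::AccessChain1 { result: elem, base, idx },
            Inst::AccessChain1 { result, base: last_base, idx: 0 },
        ) if last_base == elem => Some(Inst::AccessChain2 { result, base, idx: [idx, 0] }),
        _ => None,
    }
}

fn main() {
    let split = (
        Inst::AccessChain1 { result: 10, base: 1, idx: 3 },
        Inst::AccessChain1 { result: 11, base: 10, idx: 0 },
    );
    assert_eq!(
        fuse(split.0, split.1),
        Some(Inst::AccessChain2 { result: 11, base: 1, idx: [3, 0] })
    );
}
```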