Skip to content

Commit 9244046

Browse files
LegNeatoFirestar99
authored andcommitted
Optimize From::from calls with constant arguments
1 parent 8c0ec8e commit 9244046

File tree

7 files changed

+167
-13
lines changed

7 files changed

+167
-13
lines changed

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3193,6 +3193,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
31933193
.and_then(|def_id| self.buffer_store_intrinsics.borrow().get(&def_id).copied());
31943194
let is_panic_entry_point = instance_def_id
31953195
.is_some_and(|def_id| self.panic_entry_points.borrow().contains(&def_id));
3196+
let from_trait_impl =
3197+
instance_def_id.and_then(|def_id| self.from_trait_impls.borrow().get(&def_id).copied());
31963198

31973199
if let Some(libm_intrinsic) = libm_intrinsic {
31983200
let result = self.call_libm_intrinsic(libm_intrinsic, result_type, args);
@@ -3204,8 +3206,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
32043206
self.debug_type(result.ty),
32053207
);
32063208
}
3207-
result
3208-
} else if is_panic_entry_point {
3209+
return result;
3210+
}
3211+
3212+
if is_panic_entry_point {
32093213
// HACK(eddyb) Rust 2021 `panic!` always uses `format_args!`, even
32103214
// in the simple case that used to pass a `&str` constant, which
32113215
// would not remain reachable in the SPIR-V - but `format_args!` is
@@ -3678,24 +3682,75 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
36783682
// HACK(eddyb) redirect any possible panic call to an abort, to avoid
36793683
// needing to materialize `&core::panic::Location` or `format_args!`.
36803684
self.abort_with_kind_and_message_debug_printf("panic", message, debug_printf_args);
3681-
self.undef(result_type)
3682-
} else if let Some(mode) = buffer_load_intrinsic {
3683-
self.codegen_buffer_load_intrinsic(result_type, args, mode)
3684-
} else if let Some(mode) = buffer_store_intrinsic {
3685+
return self.undef(result_type);
3686+
}
3687+
3688+
if let Some(mode) = buffer_load_intrinsic {
3689+
return self.codegen_buffer_load_intrinsic(result_type, args, mode);
3690+
}
3691+
3692+
if let Some(mode) = buffer_store_intrinsic {
36853693
self.codegen_buffer_store_intrinsic(args, mode);
36863694

36873695
let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self);
3688-
SpirvValue {
3696+
return SpirvValue {
36893697
kind: SpirvValueKind::IllegalTypeUsed(void_ty),
36903698
ty: void_ty,
3699+
};
3700+
}
3701+
3702+
if let Some((source_ty, target_ty)) = from_trait_impl {
3703+
// Optimize From::from calls with constant arguments to avoid creating intermediate types.
3704+
if let [arg] = args {
3705+
if let Some(const_val) = self.builder.lookup_const_scalar(*arg) {
3706+
use rustc_middle::ty::{FloatTy, IntTy, UintTy};
3707+
3708+
let optimized_result = match (source_ty.kind(), target_ty.kind()) {
3709+
// Unsigned integer widening conversions
3710+
(
3711+
ty::Uint(UintTy::U8),
3712+
ty::Uint(UintTy::U16 | UintTy::U32 | UintTy::U64 | UintTy::U128),
3713+
)
3714+
| (
3715+
ty::Uint(UintTy::U16),
3716+
ty::Uint(UintTy::U32 | UintTy::U64 | UintTy::U128),
3717+
)
3718+
| (ty::Uint(UintTy::U32), ty::Uint(UintTy::U64 | UintTy::U128))
3719+
| (ty::Uint(UintTy::U64), ty::Uint(UintTy::U128))
3720+
// Signed integer widening conversions
3721+
| (
3722+
ty::Int(IntTy::I8),
3723+
ty::Int(IntTy::I16 | IntTy::I32 | IntTy::I64 | IntTy::I128),
3724+
)
3725+
| (ty::Int(IntTy::I16), ty::Int(IntTy::I32 | IntTy::I64 | IntTy::I128))
3726+
| (ty::Int(IntTy::I32), ty::Int(IntTy::I64 | IntTy::I128))
3727+
| (ty::Int(IntTy::I64), ty::Int(IntTy::I128)) => {
3728+
Some(self.constant_int(result_type, const_val))
3729+
}
3730+
3731+
// Float widening conversions: f32->f64
3732+
(ty::Float(FloatTy::F32), ty::Float(FloatTy::F64)) => {
3733+
let float_val = f32::from_bits(const_val as u32) as f64;
3734+
Some(self.constant_float(result_type, float_val))
3735+
}
3736+
3737+
// No optimization for narrowing conversions or unsupported types
3738+
_ => None,
3739+
};
3740+
3741+
if let Some(result) = optimized_result {
3742+
return result;
3743+
}
3744+
}
36913745
}
3692-
} else {
3693-
let args = args.iter().map(|arg| arg.def(self)).collect::<Vec<_>>();
3694-
self.emit()
3695-
.function_call(result_type, None, callee_val, args)
3696-
.unwrap()
3697-
.with_type(result_type)
36983746
}
3747+
3748+
// Default: emit a regular function call
3749+
let args = args.iter().map(|arg| arg.def(self)).collect::<Vec<_>>();
3750+
self.emit()
3751+
.function_call(result_type, None, callee_val, args)
3752+
.unwrap()
3753+
.with_type(result_type)
36993754
}
37003755

37013756
fn zext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {

crates/rustc_codegen_spirv/src/codegen_cx/declare.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,30 @@ impl<'tcx> CodegenCx<'tcx> {
172172
}
173173
}
174174

175+
// Check if this is a From trait implementation
176+
if let Some(impl_def_id) = self.tcx.impl_of_method(def_id) {
177+
if let Some(trait_ref) = self.tcx.impl_trait_ref(impl_def_id) {
178+
let trait_def_id = trait_ref.skip_binder().def_id;
179+
180+
// Check if this is the From trait.
181+
let trait_path = self.tcx.def_path_str(trait_def_id);
182+
if matches!(
183+
trait_path.as_str(),
184+
"core::convert::From" | "std::convert::From"
185+
) {
186+
// Extract the source and target types from the trait substitutions
187+
let trait_args = trait_ref.skip_binder().args;
188+
if let (Some(target_ty), Some(source_ty)) =
189+
(trait_args.types().nth(0), trait_args.types().nth(1))
190+
{
191+
self.from_trait_impls
192+
.borrow_mut()
193+
.insert(def_id, (source_ty, target_ty));
194+
}
195+
}
196+
}
197+
}
198+
175199
if [
176200
self.tcx.lang_items().panic_fn(),
177201
self.tcx.lang_items().panic_fmt(),

crates/rustc_codegen_spirv/src/codegen_cx/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ pub struct CodegenCx<'tcx> {
8484
/// Intrinsic for storing a `<T>` into a `&[u32]`. The `PassMode` is the mode of the `<T>`.
8585
pub buffer_store_intrinsics: RefCell<FxHashMap<DefId, &'tcx PassMode>>,
8686

87+
/// Maps `DefId`s of `From::from` method implementations to their source and target types.
88+
/// Used to optimize constant conversions like `u32::from(42u8)` to avoid creating the source type.
89+
pub from_trait_impls: RefCell<FxHashMap<DefId, (Ty<'tcx>, Ty<'tcx>)>>,
90+
8791
/// Some runtimes (e.g. intel-compute-runtime) disallow atomics on i8 and i16, even though it's allowed by the spec.
8892
/// This enables/disables them.
8993
pub i8_i16_atomics_allowed: bool,
@@ -203,6 +207,7 @@ impl<'tcx> CodegenCx<'tcx> {
203207
fmt_rt_arg_new_fn_ids_to_ty_and_spec: Default::default(),
204208
buffer_load_intrinsics: Default::default(),
205209
buffer_store_intrinsics: Default::default(),
210+
from_trait_impls: Default::default(),
206211
i8_i16_atomics_allowed: false,
207212
codegen_args,
208213
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Test that constant integer from casts are optimized to avoid creating intermediate
2+
// types that would require additional capabilities (e.g., Int8 capability for u8).
3+
4+
// build-pass
5+
// compile-flags: -C llvm-args=--disassemble-globals
6+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
7+
// normalize-stderr-test "OpSource .*\n" -> ""
8+
// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> ""
9+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
10+
11+
use spirv_std::spirv;
12+
13+
const K: u8 = 42;
14+
15+
#[spirv(fragment)]
16+
pub fn main(output: &mut u32) {
17+
let position = 2u32;
18+
// This cast should be optimized to directly create a u32 constant with value 42,
19+
// avoiding the creation of a u8 type that would require Int8 capability
20+
let global_y_offset_bits = u32::from(K);
21+
*output = global_y_offset_bits;
22+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
OpCapability Shader
2+
OpCapability ShaderClockKHR
3+
OpExtension "SPV_KHR_shader_clock"
4+
OpMemoryModel Logical Simple
5+
OpEntryPoint Fragment %1 "main" %2
6+
OpExecutionMode %1 OriginUpperLeft
7+
%3 = OpString "$OPSTRING_FILENAME/const-from-cast.rs"
8+
OpName %2 "output"
9+
OpDecorate %2 Location 0
10+
%4 = OpTypeInt 32 0
11+
%5 = OpTypePointer Output %4
12+
%6 = OpTypeVoid
13+
%7 = OpTypeFunction %6
14+
%2 = OpVariable %5 Output
15+
%8 = OpConstant %4 42
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Test that u32::from(u64) fails to compile since From<u64> is not implemented for u32
2+
// This ensures our From trait optimization doesn't accidentally allow invalid conversions
3+
4+
// build-fail
5+
6+
use spirv_std::spirv;
7+
8+
const K: u64 = 42;
9+
10+
#[spirv(fragment)]
11+
pub fn main(output: &mut u32) {
12+
// This should fail to compile because From<u64> is not implemented for u32
13+
// (u64 to u32 is a narrowing conversion that could lose data)
14+
let value = u32::from(K);
15+
*output = value;
16+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
error[E0277]: the trait bound `u32: From<u64>` is not satisfied
2+
--> $DIR/u32-from-u64-fail.rs:14:17
3+
|
4+
14 | let value = u32::from(K);
5+
| ^^^ the trait `From<u64>` is not implemented for `u32`
6+
|
7+
= help: the following other types implement trait `From<T>`:
8+
`u32` implements `From<Char>`
9+
`u32` implements `From<Ipv4Addr>`
10+
`u32` implements `From<bool>`
11+
`u32` implements `From<char>`
12+
`u32` implements `From<u16>`
13+
`u32` implements `From<u8>`
14+
15+
error: aborting due to 1 previous error
16+
17+
For more information about this error, try `rustc --explain E0277`.

0 commit comments

Comments
 (0)