Skip to content

Commit 1b2720a

Browse files
committed
builder: use the Rust types of asm! operands to protect against untyped pointers.
1 parent f583740 commit 1b2720a

File tree

1 file changed

+76
-2
lines changed

1 file changed

+76
-2
lines changed

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
33

44
use super::Builder;
5+
use crate::abi::ConvSpirvType;
56
use crate::builder_spirv::{BuilderCursor, SpirvValue};
67
use crate::codegen_cx::CodegenCx;
78
use crate::spirv_type::SpirvType;
@@ -12,13 +13,18 @@ use rspirv::spirv::{
1213
GroupOperation, ImageOperands, KernelProfilingInfo, LoopControl, MemoryAccess, MemorySemantics,
1314
Op, RayFlags, SelectionControl, StorageClass, Word,
1415
};
16+
use rustc_abi::{BackendRepr, Primitive};
1517
use rustc_ast::ast::{InlineAsmOptions, InlineAsmTemplatePiece};
18+
use rustc_codegen_ssa::mir::operand::OperandValue;
1619
use rustc_codegen_ssa::mir::place::PlaceRef;
17-
use rustc_codegen_ssa::traits::{AsmBuilderMethods, InlineAsmOperandRef};
20+
use rustc_codegen_ssa::traits::{
21+
AsmBuilderMethods, BackendTypes, BuilderMethods, InlineAsmOperandRef,
22+
};
1823
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
1924
use rustc_middle::{bug, ty::Instance};
2025
use rustc_span::{DUMMY_SP, Span};
2126
use rustc_target::asm::{InlineAsmRegClass, InlineAsmRegOrRegClass, SpirVInlineAsmRegClass};
27+
use smallvec::SmallVec;
2228

2329
pub struct InstructionTable {
2430
table: FxHashMap<&'static str, &'static rspirv::grammar::Instruction<'static>>,
@@ -33,6 +39,35 @@ impl InstructionTable {
3339
}
3440
}
3541

42+
// HACK(eddyb) `InlineAsmOperandRef` lacks `#[derive(Clone)]`
43+
fn inline_asm_operand_ref_clone<'tcx, B: BackendTypes + ?Sized>(
44+
operand: &InlineAsmOperandRef<'tcx, B>,
45+
) -> InlineAsmOperandRef<'tcx, B> {
46+
use InlineAsmOperandRef::*;
47+
48+
match operand {
49+
&In { reg, value } => In { reg, value },
50+
&Out { reg, late, place } => Out { reg, late, place },
51+
&InOut {
52+
reg,
53+
late,
54+
in_value,
55+
out_place,
56+
} => InOut {
57+
reg,
58+
late,
59+
in_value,
60+
out_place,
61+
},
62+
Const { string } => Const {
63+
string: string.clone(),
64+
},
65+
&SymFn { instance } => SymFn { instance },
66+
&SymStatic { def_id } => SymStatic { def_id },
67+
&Label { label } => Label { label },
68+
}
69+
}
70+
3671
impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
3772
/* Example asm and the template it compiles to:
3873
asm!(
@@ -70,6 +105,45 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
70105
if !unsupported_options.is_empty() {
71106
self.err(format!("asm flags not supported: {unsupported_options:?}"));
72107
}
108+
109+
// HACK(eddyb) get more accurate pointers types, for pointer operands,
110+
// from the Rust types available in their respective `OperandRef`s.
111+
let mut operands: SmallVec<[_; 8]> =
112+
operands.iter().map(inline_asm_operand_ref_clone).collect();
113+
for operand in &mut operands {
114+
let (in_value, out_place) = match operand {
115+
InlineAsmOperandRef::In { value, .. } => (Some(value), None),
116+
InlineAsmOperandRef::InOut {
117+
in_value,
118+
out_place,
119+
..
120+
} => (Some(in_value), out_place.as_mut()),
121+
InlineAsmOperandRef::Out { place, .. } => (None, place.as_mut()),
122+
123+
InlineAsmOperandRef::Const { .. }
124+
| InlineAsmOperandRef::SymFn { .. }
125+
| InlineAsmOperandRef::SymStatic { .. }
126+
| InlineAsmOperandRef::Label { .. } => (None, None),
127+
};
128+
129+
if let Some(in_value) = in_value {
130+
if let (BackendRepr::Scalar(scalar), OperandValue::Immediate(in_value_spv)) =
131+
(in_value.layout.backend_repr, &mut in_value.val)
132+
{
133+
if let Primitive::Pointer(_) = scalar.primitive() {
134+
let in_value_precise_type = in_value.layout.spirv_type(self.span(), self);
135+
*in_value_spv = self.pointercast(*in_value_spv, in_value_precise_type);
136+
}
137+
}
138+
}
139+
if let Some(out_place) = out_place {
140+
let out_place_precise_type = out_place.layout.spirv_type(self.span(), self);
141+
let out_place_precise_ptr_type = self.type_ptr_to(out_place_precise_type);
142+
out_place.val.llval =
143+
self.pointercast(out_place.val.llval, out_place_precise_ptr_type);
144+
}
145+
}
146+
73147
// vec of lines, and each line is vec of tokens
74148
let mut tokens = vec![vec![]];
75149
for piece in template {
@@ -131,7 +205,7 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
131205
let mut id_map = FxHashMap::default();
132206
let mut defined_ids = FxHashSet::default();
133207
let mut id_to_type_map = FxHashMap::default();
134-
for operand in operands {
208+
for operand in &operands {
135209
if let InlineAsmOperandRef::In { reg: _, value } = operand {
136210
let value = value.immediate();
137211
id_to_type_map.insert(value.def(self), value.ty);

0 commit comments

Comments
 (0)