Skip to content

Commit 5976674

Browse files
committed
Refactor
1 parent 02b3234 commit 5976674

File tree

5 files changed

+76
-58
lines changed

5 files changed

+76
-58
lines changed

src/librustc_mir/interpret/intrinsics.rs

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -240,24 +240,55 @@ impl<'mir, 'tcx, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
240240
self.copy_op_transmute(args[0], dest)?;
241241
}
242242
"simd_insert" => {
243-
let mut vector = self.read_vector(args[0])?;
244-
let index = self.read_scalar(args[1])?.to_u32()? as usize;
243+
let index = self.read_scalar(args[1])?.to_u32()? as u64;
245244
let scalar = self.read_immediate(args[2])?;
246-
if vector[index].layout.size == scalar.layout.size {
247-
vector[index] = scalar;
248-
} else {
249-
throw_ub_format!(
250-
"Inserting `{}` with size `{}` to a vector element place of size `{}`",
251-
scalar.layout.ty,
252-
scalar.layout.size.bytes(), vector[index].layout.size.bytes()
253-
);
245+
let input = args[0];
246+
let (len, e_ty) = self.read_vector_ty(input);
247+
assert!(
248+
index < len,
249+
"index `{}` must be in bounds of vector type `{}`: `[0, {})`",
250+
index, e_ty, len
251+
);
252+
assert_eq!(
253+
args[0].layout, dest.layout,
254+
"Return type `{}` must match vector type `{}`",
255+
dest.layout.ty, input.layout.ty
256+
);
257+
assert_eq!(
258+
scalar.layout.ty, e_ty,
259+
"Scalar type `{}` must match vector element type `{}`",
260+
scalar.layout.ty, e_ty
261+
);
262+
263+
for i in 0..len {
264+
let place = self.place_field(dest, index)?;
265+
if i == index {
266+
self.write_immediate(*scalar, place)?;
267+
} else {
268+
self.write_immediate(
269+
*self.read_immediate(self.operand_field(input, index)?)?,
270+
place
271+
)?;
272+
};
254273
}
255-
self.write_vector(vector, dest)?;
256274
}
257275
"simd_extract" => {
258276
let index = self.read_scalar(args[1])?.to_u32()? as _;
259-
let scalar = self.read_immediate(self.operand_field(args[0], index)?)?;
260-
self.write_immediate(*scalar, dest)?;
277+
let (len, e_ty) = self.read_vector_ty(args[0]);
278+
assert!(
279+
index < len,
280+
"index `{}` must be in bounds of vector type `{}`: `[0, {})`",
281+
index, e_ty, len
282+
);
283+
assert_eq!(
284+
e_ty, dest.layout.ty,
285+
"Return type `{}` must match vector element type `{}`",
286+
dest.layout.ty, e_ty
287+
);
288+
self.write_immediate(
289+
*self.read_immediate(self.operand_field(args[0], index)?)?,
290+
dest
291+
)?;
261292
}
262293
_ => return Ok(false),
263294
}

src/librustc_mir/interpret/operand.rs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -335,18 +335,15 @@ impl<'mir, 'tcx, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
335335
}
336336
}
337337

338-
/// Read vector from operand `op`
339-
pub fn read_vector(&self, op: OpTy<'tcx, M::PointerTag>)
340-
-> InterpResult<'tcx, Vec<ImmTy<'tcx, M::PointerTag>>> {
341-
if let layout::Abi::Vector { count, .. } = op.layout.abi {
342-
assert_ne!(count, 0);
343-
let mut scalars = Vec::new();
344-
for index in 0..count {
345-
scalars.push(self.read_immediate(self.operand_field(op, index as _)?)?);
346-
}
347-
Ok(scalars)
338+
/// Read vector length and element type
339+
pub fn read_vector_ty(
340+
&self, op: OpTy<'tcx, M::PointerTag>
341+
)
342+
-> (u64, &rustc::ty::TyS<'tcx>) {
343+
if let layout::Abi::Vector { .. } = op.layout.abi {
344+
(op.layout.ty.simd_size(*self.tcx) as _, op.layout.ty.simd_type(*self.tcx))
348345
} else {
349-
bug!("type is not a vector: {:?}, abi: {:?}", op.layout.ty, op.layout.abi);
346+
bug!("Type `{}` is not a SIMD vector type", op.layout.ty)
350347
}
351348
}
352349

src/librustc_mir/interpret/place.rs

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -696,40 +696,6 @@ where
696696
Ok(())
697697
}
698698

699-
/// Writes the `scalar` to the `index`-th element of the `vector`.
700-
pub fn write_scalar_to_vector(
701-
&mut self,
702-
scalar: ImmTy<'tcx, M::PointerTag>,
703-
vector: PlaceTy<'tcx, M::PointerTag>,
704-
index: usize,
705-
) -> InterpResult<'tcx> {
706-
let index = index as u64;
707-
let place = self.place_field(vector, index)?;
708-
self.write_immediate(*scalar, place)?;
709-
Ok(())
710-
}
711-
712-
/// Writes the `scalars` to the `vector`.
713-
pub fn write_vector(
714-
&mut self,
715-
scalars: Vec<ImmTy<'tcx, M::PointerTag>>,
716-
vector: PlaceTy<'tcx, M::PointerTag>,
717-
) -> InterpResult<'tcx> {
718-
assert_ne!(scalars.len(), 0);
719-
match vector.layout.ty.sty {
720-
ty::Adt(def, ..) if def.repr.simd() => {
721-
let tcx = &*self.tcx;
722-
let count = vector.layout.ty.simd_size(*tcx);
723-
assert_eq!(count, scalars.len());
724-
for index in 0..scalars.len() {
725-
self.write_scalar_to_vector(scalars[index], vector, index)?;
726-
}
727-
}
728-
_ => bug!("not a vector"),
729-
}
730-
Ok(())
731-
}
732-
733699
/// Write an `Immediate` to memory.
734700
#[inline(always)]
735701
pub fn write_immediate_to_mplace(

src/test/ui/consts/const-eval/simd/insert_extract-fail.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,21 @@
77

88
extern "platform-intrinsic" {
99
fn simd_insert<T, U>(x: T, idx: u32, val: U) -> T;
10+
fn simd_extract<T, U>(x: T, idx: u32) -> U;
1011
}
1112

1213
const fn foo(x: i8x1) -> i8 {
1314
// 42 is a i16 that does not fit in a i8
1415
unsafe { simd_insert(x, 0_u32, 42_i16) }.0 //~ ERROR
1516
}
1617

18+
const fn bar(x: i8x1) -> i16 {
19+
// the i8 is not a i16:
20+
unsafe { simd_extract(x, 0_u32) } //~ ERROR
21+
}
22+
1723
fn main() {
1824
const V: i8x1 = i8x1(13);
1925
const X: i8 = foo(V);
26+
const Y: i16 = bar(V);
2027
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#![feature(const_fn)]
2+
#![feature(platform_intrinsics)]
3+
#![allow(non_camel_case_types)]
4+
5+
extern "platform-intrinsic" {
6+
fn simd_extract<T, U>(x: T, idx: u32) -> U;
7+
}
8+
9+
const fn foo(x: i8) -> i8 {
10+
// i8 is not a vector type:
11+
unsafe { simd_extract(x, 0_u32) } //~ ERROR
12+
}
13+
14+
fn main() {
15+
const V: i8 = 13;
16+
const X: i8 = foo(V);
17+
}

0 commit comments

Comments
 (0)