Skip to content

Commit a654186

Browse files
committed
Implement simd_select_bitmask
1 parent ddc152b commit a654186

File tree

2 files changed

+57
-8
lines changed

2 files changed

+57
-8
lines changed

src/builder.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,17 +1438,18 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
14381438

14391439
pub fn vector_select(&mut self, cond: RValue<'gcc>, then_val: RValue<'gcc>, else_val: RValue<'gcc>) -> RValue<'gcc> {
14401440
// cond is a vector of integers, not of bools.
1441-
let vector_type = cond.get_type().dyncast_vector().expect("vector type");
1441+
let cond_type = cond.get_type();
1442+
let vector_type = cond_type.unqualified().dyncast_vector().expect("vector type");
14421443
let num_units = vector_type.get_num_units();
1443-
let vector_type = self.context.new_vector_type(self.int_type, num_units as u64);
1444-
let zeros = vec![self.context.new_rvalue_zero(self.int_type); num_units];
1445-
let zeros = self.context.new_rvalue_from_vector(None, vector_type, &zeros);
1444+
let element_type = vector_type.get_element_type();
1445+
let zeros = vec![self.context.new_rvalue_zero(element_type); num_units];
1446+
let zeros = self.context.new_rvalue_from_vector(None, cond_type, &zeros);
14461447

14471448
let masks = self.context.new_comparison(None, ComparisonOp::NotEquals, cond, zeros);
14481449
let then_vals = masks & then_val;
14491450

1450-
let ones = vec![self.context.new_rvalue_one(self.int_type); num_units];
1451-
let ones = self.context.new_rvalue_from_vector(None, vector_type, &ones);
1451+
let ones = vec![self.context.new_rvalue_one(element_type); num_units];
1452+
let ones = self.context.new_rvalue_from_vector(None, cond_type, &ones);
14521453
let inverted_masks = masks + ones;
14531454
let else_vals = inverted_masks & else_val;
14541455

src/intrinsic/simd.rs

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ use gccjit::{BinaryOp, RValue, Type, ToRValue};
44
use rustc_codegen_ssa::base::compare_simd_types;
55
use rustc_codegen_ssa::common::{TypeKind, span_invalid_monomorphization_error};
66
use rustc_codegen_ssa::mir::operand::OperandRef;
7+
use rustc_codegen_ssa::mir::place::PlaceRef;
78
use rustc_codegen_ssa::traits::{BaseTypeMethods, BuilderMethods};
89
use rustc_hir as hir;
910
use rustc_middle::span_bug;
1011
use rustc_middle::ty::layout::HasTyCtxt;
1112
use rustc_middle::ty::{self, Ty};
1213
use rustc_span::{Span, Symbol, sym};
14+
use rustc_target::abi::Align;
1315

1416
use crate::builder::Builder;
1517
use crate::intrinsic;
@@ -55,7 +57,53 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(bx: &mut Builder<'a, 'gcc, 'tcx>,
5557
let sig =
5658
tcx.normalize_erasing_late_bound_regions(ty::ParamEnv::reveal_all(), callee_ty.fn_sig(tcx));
5759
let arg_tys = sig.inputs();
58-
let name_str = name.as_str();
60+
61+
if name == sym::simd_select_bitmask {
62+
require_simd!(arg_tys[1], "argument");
63+
let (len, _) = arg_tys[1].simd_size_and_type(bx.tcx());
64+
65+
let expected_int_bits = (len.max(8) - 1).next_power_of_two();
66+
let expected_bytes = len / 8 + ((len % 8 > 0) as u64);
67+
68+
let mask_ty = arg_tys[0];
69+
let mut mask = match mask_ty.kind() {
70+
ty::Int(i) if i.bit_width() == Some(expected_int_bits) => args[0].immediate(),
71+
ty::Uint(i) if i.bit_width() == Some(expected_int_bits) => args[0].immediate(),
72+
ty::Array(elem, len)
73+
if matches!(elem.kind(), ty::Uint(ty::UintTy::U8))
74+
&& len.try_eval_usize(bx.tcx, ty::ParamEnv::reveal_all())
75+
== Some(expected_bytes) =>
76+
{
77+
let place = PlaceRef::alloca(bx, args[0].layout);
78+
args[0].val.store(bx, place);
79+
let int_ty = bx.type_ix(expected_bytes * 8);
80+
let ptr = bx.pointercast(place.llval, bx.cx.type_ptr_to(int_ty));
81+
bx.load(int_ty, ptr, Align::ONE)
82+
}
83+
_ => return_error!(
84+
"invalid bitmask `{}`, expected `u{}` or `[u8; {}]`",
85+
mask_ty,
86+
expected_int_bits,
87+
expected_bytes
88+
),
89+
};
90+
91+
let arg1 = args[1].immediate();
92+
let arg1_type = arg1.get_type();
93+
let arg1_vector_type = arg1_type.unqualified().dyncast_vector().expect("vector type");
94+
let arg1_element_type = arg1_vector_type.get_element_type();
95+
96+
let mut elements = vec![];
97+
let one = bx.context.new_rvalue_one(mask.get_type());
98+
for _ in 0..len {
99+
let element = bx.context.new_cast(None, mask & one, arg1_element_type);
100+
elements.push(element);
101+
mask = mask >> one;
102+
}
103+
let vector_mask = bx.context.new_rvalue_from_vector(None, arg1_type, &elements);
104+
105+
return Ok(bx.vector_select(vector_mask, arg1, args[2].immediate()));
106+
}
59107

60108
// every intrinsic below takes a SIMD vector as its first argument
61109
require_simd!(arg_tys[0], "input");
@@ -102,7 +150,7 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(bx: &mut Builder<'a, 'gcc, 'tcx>,
102150
));
103151
}
104152

105-
if let Some(stripped) = name_str.strip_prefix("simd_shuffle") {
153+
if let Some(stripped) = name.as_str().strip_prefix("simd_shuffle") {
106154
let n: u64 =
107155
if stripped.is_empty() {
108156
// Make sure this is actually an array, since typeck only checks the length-suffixed

0 commit comments

Comments
 (0)