Skip to content

Commit ddc152b

Browse files
committed
Add more SIMD
1 parent 4636c59 commit ddc152b

File tree

5 files changed

+204
-31
lines changed

5 files changed

+204
-31
lines changed

src/asm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ fn reg_to_gcc(reg: InlineAsmRegOrRegClass) -> ConstraintOrRegister {
595595
InlineAsmRegClass::X86(X86InlineAsmRegClass::xmm_reg)
596596
| InlineAsmRegClass::X86(X86InlineAsmRegClass::ymm_reg) => "x",
597597
InlineAsmRegClass::X86(X86InlineAsmRegClass::zmm_reg) => "v",
598-
InlineAsmRegClass::X86(X86InlineAsmRegClass::kreg) => unimplemented!(),
598+
InlineAsmRegClass::X86(X86InlineAsmRegClass::kreg) => "Yk",
599599
InlineAsmRegClass::Wasm(WasmInlineAsmRegClass::local) => unimplemented!(),
600600
InlineAsmRegClass::X86(
601601
X86InlineAsmRegClass::x87_reg | X86InlineAsmRegClass::mmx_reg,

src/builder.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use gccjit::{
77
BinaryOp,
88
Block,
99
ComparisonOp,
10+
Context,
1011
Function,
1112
LValue,
1213
RValue,
@@ -1380,6 +1381,85 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
13801381
pub fn shuffle_vector(&mut self, _v1: RValue<'gcc>, _v2: RValue<'gcc>, _mask: RValue<'gcc>) -> RValue<'gcc> {
13811382
unimplemented!();
13821383
}
1384+
1385+
pub fn vector_reduce<F>(&mut self, src: RValue<'gcc>, op: F) -> RValue<'gcc>
1386+
where F: Fn(RValue<'gcc>, RValue<'gcc>, &'gcc Context<'gcc>) -> RValue<'gcc>
1387+
{
1388+
let vector_type = src.get_type().unqualified().dyncast_vector().expect("vector type");
1389+
let element_count = vector_type.get_num_units();
1390+
let mut vector_elements = vec![];
1391+
for i in 0..element_count {
1392+
vector_elements.push(i);
1393+
}
1394+
let mask_type = self.context.new_vector_type(self.int_type, element_count as u64);
1395+
let mut shift = 1;
1396+
let mut res = src;
1397+
while shift < element_count {
1398+
let vector_elements: Vec<_> =
1399+
vector_elements.iter()
1400+
.map(|i| self.context.new_rvalue_from_int(self.int_type, ((i + shift) % element_count) as i32))
1401+
.collect();
1402+
let mask = self.context.new_rvalue_from_vector(None, mask_type, &vector_elements);
1403+
let shifted = self.context.new_rvalue_vector_perm(None, res, res, mask);
1404+
shift *= 2;
1405+
res = op(res, shifted, &self.context);
1406+
}
1407+
self.context.new_vector_access(None, res, self.context.new_rvalue_zero(self.int_type))
1408+
.to_rvalue()
1409+
}
1410+
1411+
pub fn vector_reduce_op(&mut self, src: RValue<'gcc>, op: BinaryOp) -> RValue<'gcc> {
1412+
self.vector_reduce(src, |a, b, context| context.new_binary_op(None, op, a.get_type(), a, b))
1413+
}
1414+
1415+
pub fn vector_reduce_fadd_fast(&mut self, acc: RValue<'gcc>, src: RValue<'gcc>) -> RValue<'gcc> {
1416+
unimplemented!();
1417+
}
1418+
1419+
pub fn vector_reduce_fmul_fast(&mut self, acc: RValue<'gcc>, src: RValue<'gcc>) -> RValue<'gcc> {
1420+
unimplemented!();
1421+
}
1422+
1423+
// Inspired by Hacker's Delight min implementation.
1424+
pub fn vector_reduce_min(&mut self, src: RValue<'gcc>) -> RValue<'gcc> {
1425+
self.vector_reduce(src, |a, b, context| {
1426+
let differences_or_zeros = difference_or_zero(a, b, context);
1427+
context.new_binary_op(None, BinaryOp::Minus, a.get_type(), a, differences_or_zeros)
1428+
})
1429+
}
1430+
1431+
// Inspired by Hacker's Delight max implementation.
1432+
pub fn vector_reduce_max(&mut self, src: RValue<'gcc>) -> RValue<'gcc> {
1433+
self.vector_reduce(src, |a, b, context| {
1434+
let differences_or_zeros = difference_or_zero(a, b, context);
1435+
context.new_binary_op(None, BinaryOp::Plus, b.get_type(), b, differences_or_zeros)
1436+
})
1437+
}
1438+
1439+
pub fn vector_select(&mut self, cond: RValue<'gcc>, then_val: RValue<'gcc>, else_val: RValue<'gcc>) -> RValue<'gcc> {
1440+
// cond is a vector of integers, not of bools.
1441+
let vector_type = cond.get_type().dyncast_vector().expect("vector type");
1442+
let num_units = vector_type.get_num_units();
1443+
let vector_type = self.context.new_vector_type(self.int_type, num_units as u64);
1444+
let zeros = vec![self.context.new_rvalue_zero(self.int_type); num_units];
1445+
let zeros = self.context.new_rvalue_from_vector(None, vector_type, &zeros);
1446+
1447+
let masks = self.context.new_comparison(None, ComparisonOp::NotEquals, cond, zeros);
1448+
let then_vals = masks & then_val;
1449+
1450+
let ones = vec![self.context.new_rvalue_one(self.int_type); num_units];
1451+
let ones = self.context.new_rvalue_from_vector(None, vector_type, &ones);
1452+
let inverted_masks = masks + ones;
1453+
let else_vals = inverted_masks & else_val;
1454+
1455+
then_vals | else_vals
1456+
}
1457+
}
1458+
1459+
fn difference_or_zero<'gcc>(a: RValue<'gcc>, b: RValue<'gcc>, context: &'gcc Context<'gcc>) -> RValue<'gcc> {
1460+
let difference = a - b;
1461+
let masks = context.new_comparison(None, ComparisonOp::GreaterThanEquals, b, a);
1462+
difference & masks
13831463
}
13841464

13851465
impl<'a, 'gcc, 'tcx> StaticBuilderMethods for Builder<'a, 'gcc, 'tcx> {

src/common.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ impl<'gcc, 'tcx> ConstMethods<'tcx> for CodegenCx<'gcc, 'tcx> {
117117
unimplemented!();
118118
}
119119

120-
fn const_real(&self, _t: Type<'gcc>, _val: f64) -> RValue<'gcc> {
121-
unimplemented!();
120+
fn const_real(&self, typ: Type<'gcc>, val: f64) -> RValue<'gcc> {
121+
self.context.new_rvalue_from_double(typ, val)
122122
}
123123

124124
fn const_str(&self, s: Symbol) -> (RValue<'gcc>, RValue<'gcc>) {

src/intrinsic/simd.rs

Lines changed: 117 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::cmp::Ordering;
22

3-
use gccjit::{RValue, Type, ToRValue};
3+
use gccjit::{BinaryOp, RValue, Type, ToRValue};
44
use rustc_codegen_ssa::base::compare_simd_types;
55
use rustc_codegen_ssa::common::{TypeKind, span_invalid_monomorphization_error};
66
use rustc_codegen_ssa::mir::operand::OperandRef;
@@ -222,6 +222,24 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(bx: &mut Builder<'a, 'gcc, 'tcx>,
222222
return Ok(bx.context.new_vector_access(None, vector, args[1].immediate()).to_rvalue());
223223
}
224224

225+
if name == sym::simd_select {
226+
let m_elem_ty = in_elem;
227+
let m_len = in_len;
228+
require_simd!(arg_tys[1], "argument");
229+
let (v_len, _) = arg_tys[1].simd_size_and_type(bx.tcx());
230+
require!(
231+
m_len == v_len,
232+
"mismatched lengths: mask length `{}` != other vector length `{}`",
233+
m_len,
234+
v_len
235+
);
236+
match m_elem_ty.kind() {
237+
ty::Int(_) => {}
238+
_ => return_error!("mask element type is `{}`, expected `i_`", m_elem_ty),
239+
}
240+
return Ok(bx.vector_select(args[0].immediate(), args[1].immediate(), args[2].immediate()));
241+
}
242+
225243
if name == sym::simd_cast {
226244
require_simd!(ret_ty, "return");
227245
let (out_len, out_elem) = ret_ty.simd_size_and_type(bx.tcx());
@@ -543,7 +561,7 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(bx: &mut Builder<'a, 'gcc, 'tcx>,
543561
}
544562

545563
macro_rules! arith_red {
546-
($name:ident : $integer_reduce:ident, $float_reduce:ident, $ordered:expr, $op:ident,
564+
($name:ident : $vec_op:expr, $float_reduce:ident, $ordered:expr, $op:ident,
547565
$identity:expr) => {
548566
if name == sym::$name {
549567
require!(
@@ -555,36 +573,25 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(bx: &mut Builder<'a, 'gcc, 'tcx>,
555573
);
556574
return match in_elem.kind() {
557575
ty::Int(_) | ty::Uint(_) => {
558-
let r = bx.$integer_reduce(args[0].immediate());
576+
let r = bx.vector_reduce_op(args[0].immediate(), $vec_op);
559577
if $ordered {
560578
// if overflow occurs, the result is the
561579
// mathematical result modulo 2^n:
562580
Ok(bx.$op(args[1].immediate(), r))
563-
} else {
564-
Ok(bx.$integer_reduce(args[0].immediate()))
581+
}
582+
else {
583+
Ok(bx.vector_reduce_op(args[0].immediate(), $vec_op))
565584
}
566585
}
567-
ty::Float(f) => {
568-
let acc = if $ordered {
586+
ty::Float(_) => {
587+
if $ordered {
569588
// ordered arithmetic reductions take an accumulator
570-
args[1].immediate()
571-
} else {
572-
// unordered arithmetic reductions use the identity accumulator
573-
match f.bit_width() {
574-
32 => bx.const_real(bx.type_f32(), $identity),
575-
64 => bx.const_real(bx.type_f64(), $identity),
576-
v => return_error!(
577-
r#"
578-
unsupported {} from `{}` with element `{}` of size `{}` to `{}`"#,
579-
sym::$name,
580-
in_ty,
581-
in_elem,
582-
v,
583-
ret_ty
584-
),
585-
}
586-
};
587-
Ok(bx.$float_reduce(acc, args[0].immediate()))
589+
let acc = args[1].immediate();
590+
Ok(bx.$float_reduce(acc, args[0].immediate()))
591+
}
592+
else {
593+
Ok(bx.vector_reduce_op(args[0].immediate(), $vec_op))
594+
}
588595
}
589596
_ => return_error!(
590597
"unsupported {} from `{}` with element `{}` to `{}`",
@@ -598,14 +605,96 @@ unsupported {} from `{}` with element `{}` of size `{}` to `{}`"#,
598605
};
599606
}
600607

601-
// TODO: use a recursive algorithm a-la Hacker's Delight.
602608
arith_red!(
603-
simd_reduce_add_unordered: vector_reduce_add,
609+
simd_reduce_add_unordered: BinaryOp::Plus,
604610
vector_reduce_fadd_fast,
605611
false,
606612
add,
607-
0.0
613+
0.0 // TODO: Use this argument.
614+
);
615+
arith_red!(
616+
simd_reduce_mul_unordered: BinaryOp::Mult,
617+
vector_reduce_fmul_fast,
618+
false,
619+
mul,
620+
1.0
608621
);
609622

623+
macro_rules! minmax_red {
624+
($name:ident: $reduction:ident) => {
625+
if name == sym::$name {
626+
require!(
627+
ret_ty == in_elem,
628+
"expected return type `{}` (element of input `{}`), found `{}`",
629+
in_elem,
630+
in_ty,
631+
ret_ty
632+
);
633+
return match in_elem.kind() {
634+
ty::Int(_) | ty::Uint(_) | ty::Float(_) => Ok(bx.$reduction(args[0].immediate())),
635+
_ => return_error!(
636+
"unsupported {} from `{}` with element `{}` to `{}`",
637+
sym::$name,
638+
in_ty,
639+
in_elem,
640+
ret_ty
641+
),
642+
};
643+
}
644+
};
645+
}
646+
647+
minmax_red!(simd_reduce_min: vector_reduce_min);
648+
minmax_red!(simd_reduce_max: vector_reduce_max);
649+
650+
macro_rules! bitwise_red {
651+
($name:ident : $op:expr, $boolean:expr) => {
652+
if name == sym::$name {
653+
let input = if !$boolean {
654+
require!(
655+
ret_ty == in_elem,
656+
"expected return type `{}` (element of input `{}`), found `{}`",
657+
in_elem,
658+
in_ty,
659+
ret_ty
660+
);
661+
args[0].immediate()
662+
} else {
663+
match in_elem.kind() {
664+
ty::Int(_) | ty::Uint(_) => {}
665+
_ => return_error!(
666+
"unsupported {} from `{}` with element `{}` to `{}`",
667+
sym::$name,
668+
in_ty,
669+
in_elem,
670+
ret_ty
671+
),
672+
}
673+
674+
// boolean reductions operate on vectors of i1s:
675+
let i1 = bx.type_i1();
676+
let i1xn = bx.type_vector(i1, in_len as u64);
677+
bx.trunc(args[0].immediate(), i1xn)
678+
};
679+
return match in_elem.kind() {
680+
ty::Int(_) | ty::Uint(_) => {
681+
let r = bx.vector_reduce_op(input, $op);
682+
Ok(if !$boolean { r } else { bx.zext(r, bx.type_bool()) })
683+
}
684+
_ => return_error!(
685+
"unsupported {} from `{}` with element `{}` to `{}`",
686+
sym::$name,
687+
in_ty,
688+
in_elem,
689+
ret_ty
690+
),
691+
};
692+
}
693+
};
694+
}
695+
696+
bitwise_red!(simd_reduce_and: BinaryOp::BitwiseAnd, false);
697+
bitwise_red!(simd_reduce_or: BinaryOp::BitwiseOr, false);
698+
610699
unimplemented!("simd {}", name);
611700
}

src/type_.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ impl<'gcc, 'tcx> CodegenCx<'gcc, 'tcx> {
247247

248248
self.context.new_array_type(None, ty, len)
249249
}
250+
251+
pub fn type_bool(&self) -> Type<'gcc> {
252+
self.context.new_type::<bool>()
253+
}
250254
}
251255

252256
pub fn struct_fields<'gcc, 'tcx>(cx: &CodegenCx<'gcc, 'tcx>, layout: TyAndLayout<'tcx>) -> (Vec<Type<'gcc>>, bool) {

0 commit comments

Comments
 (0)