Skip to content

Commit df72765

Browse files
committed
Implement simd_scatter
1 parent 0898eab commit df72765

File tree

1 file changed

+151
-25
lines changed

1 file changed

+151
-25
lines changed

src/intrinsic/simd.rs

Lines changed: 151 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,50 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(bx: &mut Builder<'a, 'gcc, 'tcx>,
519519
cx.type_vector(elem_ty, vec_len)
520520
}
521521

522+
// Shared lowering for `simd_gather` and `simd_scatter`: loads one element through every
// pointer lane of `pointers`, then blends the loaded vector with `default` lane by lane
// according to `mask`.
//
// * `default`: vector blended with the loaded values.
// * `pointers`: vector whose lanes are reinterpreted (bitcast) as pointers to the element type.
// * `mask`: vector whose lanes choose, per lane, which shuffle operand is used.
// * `pointer_count`: when > 1 the lanes are treated as `usize` values rather than
//   `underlying_ty` values.
// * `in_len`: number of lanes.
// * `invert`: swaps the two shuffle operands.
//
// NOTE(review): every lane is dereferenced unconditionally, including lanes the mask would
// discard; confirm this is acceptable when masked-off pointers may be invalid.
fn gather<'a, 'gcc, 'tcx>(default: RValue<'gcc>, pointers: RValue<'gcc>, mask: RValue<'gcc>, pointer_count: usize, bx: &mut Builder<'a, 'gcc, 'tcx>, in_len: u64, underlying_ty: Ty<'tcx>, invert: bool) -> RValue<'gcc> {
523+
// Type of the vector that holds the loaded values.
let vector_type =
524+
if pointer_count > 1 {
525+
bx.context.new_vector_type(bx.usize_type, in_len)
526+
}
527+
else {
528+
vector_ty(bx, underlying_ty, in_len)
529+
};
530+
let elem_type = vector_type.dyncast_vector().expect("vector type").get_element_type();
531+
532+
// Load one value per lane by reinterpreting the lane of `pointers` as `*elem_type`.
let mut values = vec![];
533+
for i in 0..in_len {
534+
let index = bx.context.new_rvalue_from_long(bx.i32_type, i as i64);
535+
let int = bx.context.new_vector_access(None, pointers, index).to_rvalue();
536+
537+
let ptr_type = elem_type.make_pointer();
538+
let ptr = bx.context.new_bitcast(None, int, ptr_type);
539+
let value = ptr.dereference(None).to_rvalue();
540+
values.push(value);
541+
}
542+
543+
let vector = bx.context.new_rvalue_from_vector(None, vector_type, &values);
544+
545+
// Build the shuffle selector: lane `i` becomes `i + (in_len & mask[i])`, i.e. `i` when the
// mask lane is zero and `i + in_len` when it is all ones.
// NOTE(review): presumably mask lanes are always all-zeros or all-ones; any other bit
// pattern yields a different offset. Confirm against callers.
let mut mask_types = vec![];
546+
let mut mask_values = vec![];
547+
for i in 0..in_len {
548+
let index = bx.context.new_rvalue_from_long(bx.i32_type, i as i64);
549+
mask_types.push(bx.context.new_field(None, bx.i32_type, "m")); // TODO: choose an integer based on the size of the vector element type.
550+
let mask_value = bx.context.new_vector_access(None, mask, index).to_rvalue();
551+
let masked = bx.context.new_rvalue_from_int(bx.i32_type, in_len as i32) & mask_value;
552+
let value = index + masked;
553+
mask_values.push(value);
554+
}
555+
// The selector is passed to `shuffle_vector` as a struct of `in_len` i32 fields.
let mask_type = bx.context.new_struct_type(None, "mask_type", &mask_types);
556+
let mask = bx.context.new_struct_constructor(None, mask_type.as_type(), None, &mask_values);
557+
558+
// Presumably `shuffle_vector(a, b, sel)` indexes into the concatenation of `a` and `b`
// (TODO confirm). Under that reading a set mask lane picks the second operand: the loaded
// value normally, or `default` when `invert` is true.
if invert {
559+
bx.shuffle_vector(vector, default, mask)
560+
}
561+
else {
562+
bx.shuffle_vector(default, vector, mask)
563+
}
564+
}
565+
522566
if name == sym::simd_gather {
523567
// simd_gather(values: <N x T>, pointers: <N x *_ T>,
524568
// mask: <N x i{M}>) -> <N x T>
@@ -616,6 +660,108 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(bx: &mut Builder<'a, 'gcc, 'tcx>,
616660
}
617661
}
618662

663+
return Ok(gather(args[0].immediate(), args[1].immediate(), args[2].immediate(), pointer_count, bx, in_len, underlying_ty, false));
664+
}
665+
666+
if name == sym::simd_scatter {
667+
// simd_scatter(values: <N x T>, pointers: <N x *mut T>,
668+
// mask: <N x i{M}>) -> ()
669+
// * N: number of elements in the input vectors
670+
// * T: type of the element to store
671+
// * M: any integer width is supported, will be truncated to i1
672+
673+
// All types must be simd vector types
674+
require_simd!(in_ty, "first");
675+
require_simd!(arg_tys[1], "second");
676+
require_simd!(arg_tys[2], "third");
677+
678+
// Of the same length:
679+
let (element_len1, _) = arg_tys[1].simd_size_and_type(bx.tcx());
680+
let (element_len2, _) = arg_tys[2].simd_size_and_type(bx.tcx());
681+
require!(
682+
in_len == element_len1,
683+
"expected {} argument with length {} (same as input type `{}`), \
684+
found `{}` with length {}",
685+
"second",
686+
in_len,
687+
in_ty,
688+
arg_tys[1],
689+
element_len1
690+
);
691+
require!(
692+
in_len == element_len2,
693+
"expected {} argument with length {} (same as input type `{}`), \
694+
found `{}` with length {}",
695+
"third",
696+
in_len,
697+
in_ty,
698+
arg_tys[2],
699+
element_len2
700+
);
701+
702+
// Counts the levels of raw-pointer indirection in `t`: 0 when `t` is not a raw pointer,
// otherwise 1 plus the count for its pointee type.
703+
fn ptr_count(t: Ty<'_>) -> usize {
704+
match t.kind() {
705+
ty::RawPtr(p) => 1 + ptr_count(p.ty),
706+
_ => 0,
707+
}
708+
}
709+
710+
// Strips every level of raw-pointer indirection from `t` and returns the innermost
// non-pointer type; returns `t` itself when it is not a raw pointer.
711+
fn non_ptr(t: Ty<'_>) -> Ty<'_> {
712+
match t.kind() {
713+
ty::RawPtr(p) => non_ptr(p.ty),
714+
_ => t,
715+
}
716+
}
717+
718+
// The second argument must be a simd vector with an element type that's a pointer
719+
// to the element type of the first argument
720+
let (_, element_ty0) = arg_tys[0].simd_size_and_type(bx.tcx());
721+
let (_, element_ty1) = arg_tys[1].simd_size_and_type(bx.tcx());
722+
let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
723+
let (pointer_count, underlying_ty) = match element_ty1.kind() {
724+
ty::RawPtr(p) if p.ty == in_elem && p.mutbl == hir::Mutability::Mut => {
725+
(ptr_count(element_ty1), non_ptr(element_ty1))
726+
}
727+
_ => {
728+
require!(
729+
false,
730+
"expected element type `{}` of second argument `{}` \
731+
to be a pointer to the element type `{}` of the first \
732+
argument `{}`, found `{}` != `*mut {}`",
733+
element_ty1,
734+
arg_tys[1],
735+
in_elem,
736+
in_ty,
737+
element_ty1,
738+
in_elem
739+
);
740+
unreachable!();
741+
}
742+
};
743+
assert!(pointer_count > 0);
744+
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
745+
assert_eq!(underlying_ty, non_ptr(element_ty0));
746+
747+
// The element type of the third argument must be a signed integer type of any width:
748+
match element_ty2.kind() {
749+
ty::Int(_) => (),
750+
_ => {
751+
require!(
752+
false,
753+
"expected element type `{}` of third argument `{}` \
754+
be a signed integer type",
755+
element_ty2,
756+
arg_tys[2]
757+
);
758+
}
759+
}
760+
761+
let result = gather(args[0].immediate(), args[1].immediate(), args[2].immediate(), pointer_count, bx, in_len, underlying_ty, true);
762+
763+
let pointers = args[1].immediate();
764+
619765
let vector_type =
620766
if pointer_count > 1 {
621767
bx.context.new_vector_type(bx.usize_type, in_len)
@@ -625,37 +771,17 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(bx: &mut Builder<'a, 'gcc, 'tcx>,
625771
};
626772
let elem_type = vector_type.dyncast_vector().expect("vector type").get_element_type();
627773

628-
let mut values = vec![];
629-
let pointers = args[1].immediate();
630774
for i in 0..in_len {
631-
let index = bx.context.new_rvalue_from_long(bx.i32_type, i as i64);
632-
let int = bx.context.new_vector_access(None, pointers, index).to_rvalue();
775+
let index = bx.context.new_rvalue_from_int(bx.int_type, i as i32);
776+
let value = bx.context.new_vector_access(None, result, index);
633777

778+
let int = bx.context.new_vector_access(None, pointers, index).to_rvalue();
634779
let ptr_type = elem_type.make_pointer();
635-
636780
let ptr = bx.context.new_bitcast(None, int, ptr_type);
637-
let value = ptr.dereference(None).to_rvalue();
638-
values.push(value);
639-
}
640-
641-
let vector = bx.context.new_rvalue_from_vector(None, vector_type, &values);
642-
let default = args[0].immediate();
643-
let mask = args[2].immediate();
644-
645-
let mut mask_types = vec![];
646-
let mut mask_values = vec![];
647-
for i in 0..in_len {
648-
let index = bx.context.new_rvalue_from_long(bx.i32_type, i as i64);
649-
mask_types.push(bx.context.new_field(None, bx.i32_type, "m")); // TODO: choose an integer based on the size of the vector element type.
650-
let mask_value = bx.context.new_vector_access(None, mask, index).to_rvalue();
651-
let masked = bx.context.new_rvalue_from_int(bx.i32_type, in_len as i32) & mask_value;
652-
let value = index + masked;
653-
mask_values.push(value);
781+
bx.llbb().add_assignment(None, ptr.dereference(None), value);
654782
}
655-
let mask_type = bx.context.new_struct_type(None, "mask_type", &mask_types);
656-
let mask = bx.context.new_struct_constructor(None, mask_type.as_type(), None, &mask_values);
657783

658-
return Ok(bx.shuffle_vector(default, vector, mask));
784+
return Ok(bx.context.new_rvalue_zero(bx.i32_type));
659785
}
660786

661787
arith_binary! {

0 commit comments

Comments
 (0)