Skip to content

Commit 01cc5b3

Browse files
committed
add intrinsics for portable packed simd vector reductions
1 parent e5acb0c commit 01cc5b3

File tree

6 files changed

+525
-3
lines changed

6 files changed

+525
-3
lines changed

src/librustc_llvm/ffi.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,46 @@ extern "C" {
12011201
Name: *const c_char)
12021202
-> ValueRef;
12031203

1204+
pub fn LLVMRustBuildVectorReduceFAdd(B: BuilderRef,
1205+
Acc: ValueRef,
1206+
Src: ValueRef)
1207+
-> ValueRef;
1208+
pub fn LLVMRustBuildVectorReduceFMul(B: BuilderRef,
1209+
Acc: ValueRef,
1210+
Src: ValueRef)
1211+
-> ValueRef;
1212+
pub fn LLVMRustBuildVectorReduceAdd(B: BuilderRef,
1213+
Src: ValueRef)
1214+
-> ValueRef;
1215+
pub fn LLVMRustBuildVectorReduceMul(B: BuilderRef,
1216+
Src: ValueRef)
1217+
-> ValueRef;
1218+
pub fn LLVMRustBuildVectorReduceAnd(B: BuilderRef,
1219+
Src: ValueRef)
1220+
-> ValueRef;
1221+
pub fn LLVMRustBuildVectorReduceOr(B: BuilderRef,
1222+
Src: ValueRef)
1223+
-> ValueRef;
1224+
pub fn LLVMRustBuildVectorReduceXor(B: BuilderRef,
1225+
Src: ValueRef)
1226+
-> ValueRef;
1227+
pub fn LLVMRustBuildVectorReduceMin(B: BuilderRef,
1228+
Src: ValueRef,
1229+
IsSigned: bool)
1230+
-> ValueRef;
1231+
pub fn LLVMRustBuildVectorReduceMax(B: BuilderRef,
1232+
Src: ValueRef,
1233+
IsSigned: bool)
1234+
-> ValueRef;
1235+
pub fn LLVMRustBuildVectorReduceFMin(B: BuilderRef,
1236+
Src: ValueRef,
1237+
IsNaN: bool)
1238+
-> ValueRef;
1239+
pub fn LLVMRustBuildVectorReduceFMax(B: BuilderRef,
1240+
Src: ValueRef,
1241+
IsNaN: bool)
1242+
-> ValueRef;
1243+
12041244
pub fn LLVMBuildIsNull(B: BuilderRef, Val: ValueRef, Name: *const c_char) -> ValueRef;
12051245
pub fn LLVMBuildIsNotNull(B: BuilderRef, Val: ValueRef, Name: *const c_char) -> ValueRef;
12061246
pub fn LLVMBuildPtrDiff(B: BuilderRef,

src/librustc_trans/builder.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,81 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
955955
}
956956
}
957957

958+
pub fn vector_reduce_fadd_fast(&self, acc: ValueRef, src: ValueRef) -> ValueRef {
959+
self.count_insn("vector.reduce.fadd_fast");
960+
unsafe {
961+
let instr = llvm::LLVMRustBuildVectorReduceFAdd(self.llbuilder, acc, src);
962+
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
963+
instr
964+
}
965+
}
966+
pub fn vector_reduce_fmul_fast(&self, acc: ValueRef, src: ValueRef) -> ValueRef {
967+
self.count_insn("vector.reduce.fmul_fast");
968+
unsafe {
969+
let instr = llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src);
970+
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
971+
instr
972+
}
973+
}
974+
pub fn vector_reduce_add(&self, src: ValueRef) -> ValueRef {
975+
self.count_insn("vector.reduce.add");
976+
unsafe {
977+
llvm::LLVMRustBuildVectorReduceAdd(self.llbuilder, src)
978+
}
979+
}
980+
pub fn vector_reduce_mul(&self, src: ValueRef) -> ValueRef {
981+
self.count_insn("vector.reduce.mul");
982+
unsafe {
983+
llvm::LLVMRustBuildVectorReduceMul(self.llbuilder, src)
984+
}
985+
}
986+
pub fn vector_reduce_and(&self, src: ValueRef) -> ValueRef {
987+
self.count_insn("vector.reduce.and");
988+
unsafe {
989+
llvm::LLVMRustBuildVectorReduceAnd(self.llbuilder, src)
990+
}
991+
}
992+
pub fn vector_reduce_or(&self, src: ValueRef) -> ValueRef {
993+
self.count_insn("vector.reduce.or");
994+
unsafe {
995+
llvm::LLVMRustBuildVectorReduceOr(self.llbuilder, src)
996+
}
997+
}
998+
pub fn vector_reduce_xor(&self, src: ValueRef) -> ValueRef {
999+
self.count_insn("vector.reduce.xor");
1000+
unsafe {
1001+
llvm::LLVMRustBuildVectorReduceXor(self.llbuilder, src)
1002+
}
1003+
}
1004+
pub fn vector_reduce_fmin_fast(&self, src: ValueRef) -> ValueRef {
1005+
self.count_insn("vector.reduce.fmin_fast");
1006+
unsafe {
1007+
let instr = llvm::LLVMRustBuildVectorReduceFMin(self.llbuilder, src, false);
1008+
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
1009+
instr
1010+
}
1011+
}
1012+
pub fn vector_reduce_fmax_fast(&self, src: ValueRef) -> ValueRef {
1013+
self.count_insn("vector.reduce.fmax_fast");
1014+
unsafe {
1015+
let instr = llvm::LLVMRustBuildVectorReduceFMax(self.llbuilder, src, false);
1016+
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
1017+
instr
1018+
}
1019+
}
1020+
pub fn vector_reduce_min(&self, src: ValueRef, is_signed: bool) -> ValueRef {
1021+
self.count_insn("vector.reduce.min");
1022+
unsafe {
1023+
llvm::LLVMRustBuildVectorReduceMin(self.llbuilder, src, is_signed)
1024+
}
1025+
}
1026+
pub fn vector_reduce_max(&self, src: ValueRef, is_signed: bool) -> ValueRef {
1027+
self.count_insn("vector.reduce.max");
1028+
unsafe {
1029+
llvm::LLVMRustBuildVectorReduceMax(self.llbuilder, src, is_signed)
1030+
}
1031+
}
1032+
9581033
pub fn extract_value(&self, agg_val: ValueRef, idx: u64) -> ValueRef {
9591034
self.count_insn("extractvalue");
9601035
assert_eq!(idx as c_uint as u64, idx);

src/librustc_trans/intrinsic.rs

Lines changed: 216 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,14 +1018,22 @@ fn generic_simd_intrinsic<'a, 'tcx>(
10181018
name, $($fmt)*));
10191019
}
10201020
}
1021-
macro_rules! require {
1022-
($cond: expr, $($fmt: tt)*) => {
1023-
if !$cond {
1021+
// Reports the formatted message through `emit_error!`, then leaves the enclosing
// function with `Err(())`.
macro_rules! return_error {
    ($($message: tt)*) => {{
        emit_error!($($message)*);
        return Err(());
    }}
}
1029+
1030+
// Guard: when `$condition` evaluates to false, hands the remaining tokens to
// `return_error!`, which reports them and returns from the enclosing function.
macro_rules! require {
    ($condition: expr, $($message: tt)*) => {
        if !$condition {
            return_error!($($message)*);
        }
    };
}
10291037
macro_rules! require_simd {
10301038
($ty: expr, $position: expr) => {
10311039
require!($ty.is_simd(), "expected SIMD {} type, found non-SIMD `{}`", $position, $ty)
@@ -1142,6 +1150,211 @@ fn generic_simd_intrinsic<'a, 'tcx>(
11421150
return Ok(bx.extract_element(args[0].immediate(), args[1].immediate()))
11431151
}
11441152

1153+
if name == "simd_reduce_add" {
    // The reduction yields a single lane, so the return type must be the lane type.
    require!(ret_ty == in_elem,
             "expected return type `{}` (element of input `{}`), found `{}`",
             in_elem, in_ty, ret_ty);
    let sum = match in_elem.sty {
        // Signed and unsigned integer lanes go through the same builder call.
        ty::TyInt(_) | ty::TyUint(_) => bx.vector_reduce_add(args[0].immediate()),
        ty::TyFloat(float_ty) => {
            // undef as accumulator makes the reduction unordered:
            let undef_acc = match float_ty.bit_width() {
                32 => C_undef(Type::f32(bx.cx)),
                64 => C_undef(Type::f64(bx.cx)),
                bits => {
                    return_error!(
                        "unsupported {} from `{}` with element `{}` of size `{}` to `{}`",
                        "simd_reduce_add", in_ty, in_elem, bits, ret_ty)
                }
            };
            bx.vector_reduce_fadd_fast(undef_acc, args[0].immediate())
        }
        _ => {
            return_error!("unsupported {} from `{}` with element `{}` to `{}`",
                          "simd_reduce_add", in_ty, in_elem, ret_ty)
        },
    };
    return Ok(sum);
}
1183+
1184+
if name == "simd_reduce_mul" {
    // The reduction yields a single lane, so the return type must be the lane type.
    require!(ret_ty == in_elem,
             "expected return type `{}` (element of input `{}`), found `{}`",
             in_elem, in_ty, ret_ty);
    let product = match in_elem.sty {
        // Signed and unsigned integer lanes go through the same builder call.
        ty::TyInt(_) | ty::TyUint(_) => bx.vector_reduce_mul(args[0].immediate()),
        ty::TyFloat(float_ty) => {
            // undef as accumulator makes the reduction unordered:
            let undef_acc = match float_ty.bit_width() {
                32 => C_undef(Type::f32(bx.cx)),
                64 => C_undef(Type::f64(bx.cx)),
                bits => {
                    return_error!(
                        "unsupported {} from `{}` with element `{}` of size `{}` to `{}`",
                        "simd_reduce_mul", in_ty, in_elem, bits, ret_ty)
                }
            };
            bx.vector_reduce_fmul_fast(undef_acc, args[0].immediate())
        }
        _ => {
            return_error!("unsupported {} from `{}` with element `{}` to `{}`",
                          "simd_reduce_mul", in_ty, in_elem, ret_ty)
        },
    };
    return Ok(product);
}
1214+
1215+
if name == "simd_reduce_min" {
    // The reduction yields a single lane, so the return type must be the lane type.
    require!(ret_ty == in_elem,
             "expected return type `{}` (element of input `{}`), found `{}`",
             in_elem, in_ty, ret_ty);
    let minimum = match in_elem.sty {
        // The boolean is the builder's `is_signed` argument.
        ty::TyInt(_) => bx.vector_reduce_min(args[0].immediate(), true),
        ty::TyUint(_) => bx.vector_reduce_min(args[0].immediate(), false),
        ty::TyFloat(_) => bx.vector_reduce_fmin_fast(args[0].immediate()),
        _ => {
            return_error!("unsupported {} from `{}` with element `{}` to `{}`",
                          "simd_reduce_min", in_ty, in_elem, ret_ty)
        },
    };
    return Ok(minimum);
}
1235+
1236+
if name == "simd_reduce_max" {
    // The reduction yields a single lane, so the return type must be the lane type.
    require!(ret_ty == in_elem,
             "expected return type `{}` (element of input `{}`), found `{}`",
             in_elem, in_ty, ret_ty);
    let maximum = match in_elem.sty {
        // The boolean is the builder's `is_signed` argument.
        ty::TyInt(_) => bx.vector_reduce_max(args[0].immediate(), true),
        ty::TyUint(_) => bx.vector_reduce_max(args[0].immediate(), false),
        ty::TyFloat(_) => bx.vector_reduce_fmax_fast(args[0].immediate()),
        _ => {
            return_error!("unsupported {} from `{}` with element `{}` to `{}`",
                          "simd_reduce_max", in_ty, in_elem, ret_ty)
        },
    };
    return Ok(maximum);
}
1256+
1257+
if name == "simd_reduce_and" {
    // The reduction yields a single lane, so the return type must be the lane type.
    require!(ret_ty == in_elem,
             "expected return type `{}` (element of input `{}`), found `{}`",
             in_elem, in_ty, ret_ty);
    let conjunction = match in_elem.sty {
        // Only integer lanes (signed or unsigned) are accepted.
        ty::TyInt(_) | ty::TyUint(_) => bx.vector_reduce_and(args[0].immediate()),
        _ => {
            return_error!("unsupported {} from `{}` with element `{}` to `{}`",
                          "simd_reduce_and", in_ty, in_elem, ret_ty)
        },
    };
    return Ok(conjunction);
}
1274+
1275+
if name == "simd_reduce_or" {
    // The reduction yields a single lane, so the return type must be the lane type.
    require!(ret_ty == in_elem,
             "expected return type `{}` (element of input `{}`), found `{}`",
             in_elem, in_ty, ret_ty);
    let disjunction = match in_elem.sty {
        // Only integer lanes (signed or unsigned) are accepted.
        ty::TyInt(_) | ty::TyUint(_) => bx.vector_reduce_or(args[0].immediate()),
        _ => {
            return_error!("unsupported {} from `{}` with element `{}` to `{}`",
                          "simd_reduce_or", in_ty, in_elem, ret_ty)
        },
    };
    return Ok(disjunction);
}
1292+
1293+
if name == "simd_reduce_xor" {
    // The reduction yields a single lane, so the return type must be the lane type.
    require!(ret_ty == in_elem,
             "expected return type `{}` (element of input `{}`), found `{}`",
             in_elem, in_ty, ret_ty);
    let parity = match in_elem.sty {
        // Only integer lanes (signed or unsigned) are accepted.
        ty::TyInt(_) | ty::TyUint(_) => bx.vector_reduce_xor(args[0].immediate()),
        _ => {
            return_error!("unsupported {} from `{}` with element `{}` to `{}`",
                          "simd_reduce_xor", in_ty, in_elem, ret_ty)
        },
    };
    return Ok(parity);
}
1310+
1311+
if name == "simd_reduce_all" {
    // No `ret_ty == in_elem` check here: unlike the lane-wise reductions, this
    // intrinsic is declared by the typeck signature table to return `bool`.

    // Validate the lane type *before* emitting any IR, so that an unsupported
    // (e.g. floating-point) vector is reported as an error instead of being fed
    // to an integer `trunc`.
    match in_elem.sty {
        ty::TyInt(_) | ty::TyUint(_) => {},
        _ => {
            return_error!("unsupported {} from `{}` with element `{}` to `{}`",
                          "simd_reduce_all", in_ty, in_elem, ret_ty)
        },
    }

    // Boolean reductions operate on vectors of `i1`: truncate every lane to one bit.
    let i1 = Type::i1(bx.cx);
    let i1xn = Type::vector(&i1, in_len as u64);
    let lane_bits = bx.trunc(args[0].immediate(), i1xn);

    // AND the per-lane bits together, then zero-extend to the `bool` type.
    let all_set = bx.vector_reduce_and(lane_bits);
    return Ok(bx.zext(all_set, Type::bool(bx.cx)));
}
1333+
1334+
if name == "simd_reduce_any" {
    // No `ret_ty == in_elem` check here: unlike the lane-wise reductions, this
    // intrinsic is declared by the typeck signature table to return `bool`.

    // Validate the lane type *before* emitting any IR, so that an unsupported
    // (e.g. floating-point) vector is reported as an error instead of being fed
    // to an integer `trunc`.
    match in_elem.sty {
        ty::TyInt(_) | ty::TyUint(_) => {},
        _ => {
            return_error!("unsupported {} from `{}` with element `{}` to `{}`",
                          "simd_reduce_any", in_ty, in_elem, ret_ty)
        },
    }

    // Boolean reductions operate on vectors of `i1`: truncate every lane to one bit.
    let i1 = Type::i1(bx.cx);
    let i1xn = Type::vector(&i1, in_len as u64);
    let lane_bits = bx.trunc(args[0].immediate(), i1xn);

    // OR the per-lane bits together, then zero-extend to the `bool` type.
    let any_set = bx.vector_reduce_or(lane_bits);
    return Ok(bx.zext(any_set, Type::bool(bx.cx)));
}
1356+
1357+
11451358
if name == "simd_cast" {
11461359
require_simd!(ret_ty, "return");
11471360
let out_len = ret_ty.simd_size(tcx);

src/librustc_typeck/check/intrinsic.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,11 @@ pub fn check_platform_intrinsic_type<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
361361
"simd_insert" => (2, vec![param(0), tcx.types.u32, param(1)], param(0)),
362362
"simd_extract" => (2, vec![param(0), tcx.types.u32], param(1)),
363363
"simd_cast" => (2, vec![param(0)], param(1)),
364+
"simd_reduce_all" | "simd_reduce_any" => (1, vec![param(0)], tcx.types.bool),
365+
"simd_reduce_add" | "simd_reduce_mul" |
366+
"simd_reduce_and" | "simd_reduce_or" | "simd_reduce_xor" |
367+
"simd_reduce_min" | "simd_reduce_max"
368+
=> (2, vec![param(0)], param(1)),
364369
name if name.starts_with("simd_shuffle") => {
365370
match name["simd_shuffle".len()..].parse() {
366371
Ok(n) => {

0 commit comments

Comments
 (0)