Skip to content

Commit 08d03fe

Browse files
committed
Extract common logic for converting vector masks to i1 vectors and make it use the sign bit instead of lowest bit consistently
1 parent 063b1f0 commit 08d03fe

File tree

1 file changed

+39
-40
lines changed

1 file changed

+39
-40
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,20 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
965965
}};
966966
}
967967

968+
fn vector_mask_to_bitmask<'a, 'll, 'tcx>(
969+
bx: &mut Builder<'a, 'll, 'tcx>,
970+
i_xn: &'ll Value,
971+
in_elem_bitwidth: u64,
972+
in_len: u64,
973+
) -> &'ll Value {
974+
// Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
975+
let shift_idx = bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
976+
let shift_indices = vec![shift_idx; in_len as _];
977+
let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
978+
// Truncate vector to an <i1 x N>
979+
bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len))
980+
}
981+
968982
let tcx = bx.tcx();
969983
let sig =
970984
tcx.normalize_erasing_late_bound_regions(ty::ParamEnv::reveal_all(), callee_ty.fn_sig(tcx));
@@ -1225,14 +1239,11 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
12251239
m_len == v_len,
12261240
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
12271241
);
1228-
match m_elem_ty.kind() {
1229-
ty::Int(_) => {}
1242+
let in_elem_bitwidth = match m_elem_ty.kind() {
1243+
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
12301244
_ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
1231-
}
1232-
// truncate the mask to a vector of i1s
1233-
let i1 = bx.type_i1();
1234-
let i1xn = bx.type_vector(i1, m_len as u64);
1235-
let m_i1s = bx.trunc(args[0].immediate(), i1xn);
1245+
};
1246+
let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
12361247
return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
12371248
}
12381249

@@ -1267,15 +1278,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
12671278
}),
12681279
};
12691280

1270-
// Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
1271-
let shift_indices =
1272-
vec![
1273-
bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
1274-
in_len as _
1275-
];
1276-
let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
1277-
// Truncate vector to an <i1 x N>
1278-
let i1xn = bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len));
1281+
let i1xn = vector_mask_to_bitmask(bx, i_xn, in_elem_bitwidth, in_len);
12791282
// Bitcast <i1 x N> to iN:
12801283
let i_ = bx.bitcast(i1xn, bx.type_ix(in_len));
12811284

@@ -1493,28 +1496,25 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14931496
}
14941497
);
14951498

1496-
match element_ty2.kind() {
1497-
ty::Int(_) => (),
1499+
let mask_elem_bitwidth = match element_ty2.kind() {
1500+
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
14981501
_ => {
14991502
return_error!(InvalidMonomorphization::ThirdArgElementType {
15001503
span,
15011504
name,
15021505
expected_element: element_ty2,
15031506
third_arg: arg_tys[2]
1504-
});
1507+
})
15051508
}
1506-
}
1509+
};
15071510

15081511
// Alignment of T, must be a constant integer value:
15091512
let alignment_ty = bx.type_i32();
15101513
let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
15111514

15121515
// Truncate the mask vector to a vector of i1s:
1513-
let (mask, mask_ty) = {
1514-
let i1 = bx.type_i1();
1515-
let i1xn = bx.type_vector(i1, in_len);
1516-
(bx.trunc(args[2].immediate(), i1xn), i1xn)
1517-
};
1516+
let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
1517+
let mask_ty = bx.type_vector(bx.type_i1(), in_len);
15181518

15191519
// Type of the vector of pointers:
15201520
let llvm_pointer_vec_ty = llvm_vector_ty(bx, element_ty1, in_len);
@@ -1790,8 +1790,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
17901790
);
17911791

17921792
// The element type of the third argument must be a signed integer type of any width:
1793-
match element_ty2.kind() {
1794-
ty::Int(_) => (),
1793+
let mask_elem_bitwidth = match element_ty2.kind() {
1794+
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
17951795
_ => {
17961796
return_error!(InvalidMonomorphization::ThirdArgElementType {
17971797
span,
@@ -1800,18 +1800,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18001800
third_arg: arg_tys[2]
18011801
});
18021802
}
1803-
}
1803+
};
18041804

18051805
// Alignment of T, must be a constant integer value:
18061806
let alignment_ty = bx.type_i32();
18071807
let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
18081808

18091809
// Truncate the mask vector to a vector of i1s:
1810-
let (mask, mask_ty) = {
1811-
let i1 = bx.type_i1();
1812-
let i1xn = bx.type_vector(i1, in_len);
1813-
(bx.trunc(args[2].immediate(), i1xn), i1xn)
1814-
};
1810+
let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
1811+
let mask_ty = bx.type_vector(bx.type_i1(), in_len);
18151812

18161813
let ret_t = bx.type_void();
18171814

@@ -1949,8 +1946,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19491946
);
19501947
args[0].immediate()
19511948
} else {
1952-
match in_elem.kind() {
1953-
ty::Int(_) | ty::Uint(_) => {}
1949+
let bitwidth = match in_elem.kind() {
1950+
ty::Int(i) => {
1951+
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
1952+
}
1953+
ty::Uint(i) => {
1954+
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
1955+
}
19541956
_ => return_error!(InvalidMonomorphization::UnsupportedSymbol {
19551957
span,
19561958
name,
@@ -1959,12 +1961,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19591961
in_elem,
19601962
ret_ty
19611963
}),
1962-
}
1964+
};
19631965

1964-
// boolean reductions operate on vectors of i1s:
1965-
let i1 = bx.type_i1();
1966-
let i1xn = bx.type_vector(i1, in_len as u64);
1967-
bx.trunc(args[0].immediate(), i1xn)
1966+
vector_mask_to_bitmask(bx, args[0].immediate(), bitwidth, in_len as _)
19681967
};
19691968
return match in_elem.kind() {
19701969
ty::Int(_) | ty::Uint(_) => {

0 commit comments

Comments
 (0)