Commit e540587

Use IntervalSet in InitMask rather than the custom bitset impl
1 parent 99695a3 · commit e540587
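
The commit replaces `InitMask`'s storage, a `Vec<Block>` of 64-bit words plus an explicit `len`, with `rustc_index::interval::IntervalSet`, which tracks initialization as a set of contiguous byte ranges. Initialization masks tend to consist of a few long runs (freshly allocated memory is entirely uninitialized, fully written memory entirely initialized), so an interval representation stays compact and supports whole-range updates and queries without word-by-word bit twiddling. A minimal sketch of the idea follows; the type, fields, and merge strategy here are illustrative assumptions, not the actual `IntervalSet` sources:

// A minimal sketch of interval-based init tracking. This is NOT the
// rustc_index::interval::IntervalSet implementation; the type, fields,
// and the O(n) merge below are illustrative assumptions only.
struct ToyIntervalSet {
    domain_size: usize,          // number of tracked bytes
    ranges: Vec<(usize, usize)>, // sorted, disjoint half-open runs
}

impl ToyIntervalSet {
    fn new(domain_size: usize) -> Self {
        // A fresh mask stores zero intervals, regardless of its size.
        ToyIntervalSet { domain_size, ranges: Vec::new() }
    }

    // Mark `start..end` initialized, merging overlapping/adjacent runs.
    fn insert_range(&mut self, mut start: usize, mut end: usize) {
        assert!(start <= end && end <= self.domain_size);
        let mut merged = Vec::with_capacity(self.ranges.len() + 1);
        for &(s, e) in &self.ranges {
            if e < start || end < s {
                merged.push((s, e)); // disjoint and non-adjacent: keep
            } else {
                start = start.min(s); // absorb into the new run
                end = end.max(e);
            }
        }
        merged.push((start, end));
        merged.sort_unstable();
        self.ranges = merged;
    }

    fn contains(&self, i: usize) -> bool {
        self.ranges.iter().any(|&(s, e)| s <= i && i < e)
    }
}

fn main() {
    let mut mask = ToyIntervalSet::new(1024); // 1 KiB, all uninitialized
    mask.insert_range(0, 64);
    mask.insert_range(64, 128); // adjacent: collapses into a single run
    assert_eq!(mask.ranges, vec![(0, 128)]);
    assert!(mask.contains(100));
    assert!(!mask.contains(128));
}

The trade-off is the worst case: a mask that alternates between initialized and uninitialized every byte stores one interval per run, whereas the old bitset's footprint was fixed at one bit per byte.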

1 file changed: +34 -285 lines changed


compiler/rustc_middle/src/mir/interpret/allocation.rs

Lines changed: 34 additions & 285 deletions
@@ -1,13 +1,12 @@
 //! The virtual memory representation of the MIR interpreter.
 
 use std::borrow::Cow;
-use std::convert::{TryFrom, TryInto};
-use std::iter;
 use std::ops::{Deref, Range};
 use std::ptr;
 
 use rustc_ast::Mutability;
 use rustc_data_structures::sorted_map::SortedMap;
+use rustc_index::interval::IntervalSet;
 use rustc_span::DUMMY_SP;
 use rustc_target::abi::{Align, HasDataLayout, Size};
 
@@ -567,323 +566,73 @@ impl<Tag: Copy, Extra> Allocation<Tag, Extra> {
 // Uninitialized byte tracking
 ////////////////////////////////////////////////////////////////////////////////
 
-type Block = u64;
-
 /// A bitmask where each bit refers to the byte with the same index. If the bit is `true`, the byte
 /// is initialized. If it is `false` the byte is uninitialized.
-#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, TyEncodable, TyDecodable)]
+#[derive(Clone, Debug, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
 #[derive(HashStable)]
 pub struct InitMask {
-    blocks: Vec<Block>,
-    len: Size,
+    set: IntervalSet<usize>,
 }
 
-impl InitMask {
-    pub const BLOCK_SIZE: u64 = 64;
-
-    #[inline]
-    fn bit_index(bits: Size) -> (usize, usize) {
-        // BLOCK_SIZE is the number of bits that can fit in a `Block`.
-        // Each bit in a `Block` represents the initialization state of one byte of an allocation,
-        // so we use `.bytes()` here.
-        let bits = bits.bytes();
-        let a = bits / InitMask::BLOCK_SIZE;
-        let b = bits % InitMask::BLOCK_SIZE;
-        (usize::try_from(a).unwrap(), usize::try_from(b).unwrap())
+impl Ord for InitMask {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.set
+            .iter()
+            .cmp(other.set.iter())
+            .then(self.set.domain_size().cmp(&other.set.domain_size()))
     }
+}
 
-    #[inline]
-    fn size_from_bit_index(block: impl TryInto<u64>, bit: impl TryInto<u64>) -> Size {
-        let block = block.try_into().ok().unwrap();
-        let bit = bit.try_into().ok().unwrap();
-        Size::from_bytes(block * InitMask::BLOCK_SIZE + bit)
+impl PartialOrd for InitMask {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
     }
+}
 
+impl InitMask {
     pub fn new(size: Size, state: bool) -> Self {
-        let mut m = InitMask { blocks: vec![], len: Size::ZERO };
-        m.grow(size, state);
-        m
+        let mut set = IntervalSet::new(size.bytes_usize());
+        if state {
+            set.insert_all();
+        }
+        InitMask { set }
     }
 
     pub fn set_range(&mut self, start: Size, end: Size, new_state: bool) {
-        let len = self.len;
-        if end > len {
-            self.grow(end - len, new_state);
-        }
+        self.set.ensure(end.bytes_usize() + 1);
         self.set_range_inbounds(start, end, new_state);
     }
 
     pub fn set_range_inbounds(&mut self, start: Size, end: Size, new_state: bool) {
-        let (blocka, bita) = Self::bit_index(start);
-        let (blockb, bitb) = Self::bit_index(end);
-        if blocka == blockb {
-            // First set all bits except the first `bita`,
-            // then unset the last `64 - bitb` bits.
-            let range = if bitb == 0 {
-                u64::MAX << bita
-            } else {
-                (u64::MAX << bita) & (u64::MAX >> (64 - bitb))
-            };
-            if new_state {
-                self.blocks[blocka] |= range;
-            } else {
-                self.blocks[blocka] &= !range;
-            }
-            return;
-        }
-        // across block boundaries
+        assert!(end.bytes_usize() <= self.set.domain_size());
         if new_state {
-            // Set `bita..64` to `1`.
-            self.blocks[blocka] |= u64::MAX << bita;
-            // Set `0..bitb` to `1`.
-            if bitb != 0 {
-                self.blocks[blockb] |= u64::MAX >> (64 - bitb);
-            }
-            // Fill in all the other blocks (much faster than one bit at a time).
-            for block in (blocka + 1)..blockb {
-                self.blocks[block] = u64::MAX;
-            }
+            self.set.insert_range(start.bytes_usize()..end.bytes_usize());
         } else {
-            // Set `bita..64` to `0`.
-            self.blocks[blocka] &= !(u64::MAX << bita);
-            // Set `0..bitb` to `0`.
-            if bitb != 0 {
-                self.blocks[blockb] &= !(u64::MAX >> (64 - bitb));
-            }
-            // Fill in all the other blocks (much faster than one bit at a time).
-            for block in (blocka + 1)..blockb {
-                self.blocks[block] = 0;
-            }
+            self.set.remove_range(start.bytes_usize()..end.bytes_usize());
         }
     }
 
     #[inline]
     pub fn get(&self, i: Size) -> bool {
-        let (block, bit) = Self::bit_index(i);
-        (self.blocks[block] & (1 << bit)) != 0
+        self.set.contains(i.bytes_usize())
     }
 
     #[inline]
     pub fn set(&mut self, i: Size, new_state: bool) {
-        let (block, bit) = Self::bit_index(i);
-        self.set_bit(block, bit, new_state);
-    }
-
-    #[inline]
-    fn set_bit(&mut self, block: usize, bit: usize, new_state: bool) {
         if new_state {
-            self.blocks[block] |= 1 << bit;
+            self.set.insert(i.bytes_usize());
         } else {
-            self.blocks[block] &= !(1 << bit);
+            self.set.remove(i.bytes_usize());
         }
     }
 
-    pub fn grow(&mut self, amount: Size, new_state: bool) {
-        if amount.bytes() == 0 {
-            return;
-        }
-        let unused_trailing_bits =
-            u64::try_from(self.blocks.len()).unwrap() * Self::BLOCK_SIZE - self.len.bytes();
-        if amount.bytes() > unused_trailing_bits {
-            let additional_blocks = amount.bytes() / Self::BLOCK_SIZE + 1;
-            self.blocks.extend(
-                // FIXME(oli-obk): optimize this by repeating `new_state as Block`.
-                iter::repeat(0).take(usize::try_from(additional_blocks).unwrap()),
-            );
-        }
-        let start = self.len;
-        self.len += amount;
-        self.set_range_inbounds(start, start + amount, new_state); // `Size` operation
-    }
-
     /// Returns the index of the first bit in `start..end` (end-exclusive) that is equal to is_init.
     fn find_bit(&self, start: Size, end: Size, is_init: bool) -> Option<Size> {
-        /// A fast implementation of `find_bit`,
-        /// which skips over an entire block at a time if it's all 0s (resp. 1s),
-        /// and finds the first 1 (resp. 0) bit inside a block using `trailing_zeros` instead of a loop.
-        ///
-        /// Note that all examples below are written with 8 (instead of 64) bit blocks for simplicity,
-        /// and with the least significant bit (and lowest block) first:
-        ///
-        ///          00000000|00000000
-        ///          ^      ^ ^      ^
-        ///   index: 0      7 8      15
-        ///
-        /// Also, if not stated, assume that `is_init = true`, that is, we are searching for the first 1 bit.
-        fn find_bit_fast(
-            init_mask: &InitMask,
-            start: Size,
-            end: Size,
-            is_init: bool,
-        ) -> Option<Size> {
-            /// Search one block, returning the index of the first bit equal to `is_init`.
-            fn search_block(
-                bits: Block,
-                block: usize,
-                start_bit: usize,
-                is_init: bool,
-            ) -> Option<Size> {
-                // For the following examples, assume this function was called with:
-                //   bits = 0b00111011
-                //   start_bit = 3
-                //   is_init = false
-                // Note that, for the examples in this function, the most significant bit is written first,
-                // which is backwards compared to the comments in `find_bit`/`find_bit_fast`.
-
-                // Invert bits so we're always looking for the first set bit.
-                //        ! 0b00111011
-                //   bits = 0b11000100
-                let bits = if is_init { bits } else { !bits };
-                // Mask off unused start bits.
-                //          0b11000100
-                //        & 0b11111000
-                //   bits = 0b11000000
-                let bits = bits & (!0 << start_bit);
-                // Find set bit, if any.
-                //   bit = trailing_zeros(0b11000000)
-                //   bit = 6
-                if bits == 0 {
-                    None
-                } else {
-                    let bit = bits.trailing_zeros();
-                    Some(InitMask::size_from_bit_index(block, bit))
-                }
-            }
-
-            if start >= end {
-                return None;
-            }
-
-            // Convert `start` and `end` to block indexes and bit indexes within each block.
-            // We must convert `end` to an inclusive bound to handle block boundaries correctly.
-            //
-            // For example:
-            //
-            //   (a) 00000000|00000000    (b) 00000000|
-            //       ^~~~~~~~~~~^             ^~~~~~~~~^
-            //       start       end          start    end
-            //
-            // In both cases, the block index of `end` is 1.
-            // But we do want to search block 1 in (a), and we don't in (b).
-            //
-            // We subtract 1 from both end positions to make them inclusive:
-            //
-            //   (a) 00000000|00000000    (b) 00000000|
-            //       ^~~~~~~~~~^              ^~~~~~~^
-            //       start      end_inclusive start end_inclusive
-            //
-            // For (a), the block index of `end_inclusive` is 1, and for (b), it's 0.
-            // This provides the desired behavior of searching blocks 0 and 1 for (a),
-            // and searching only block 0 for (b).
-            // There is no concern of overflows since we checked for `start >= end` above.
-            let (start_block, start_bit) = InitMask::bit_index(start);
-            let end_inclusive = Size::from_bytes(end.bytes() - 1);
-            let (end_block_inclusive, _) = InitMask::bit_index(end_inclusive);
-
-            // Handle first block: need to skip `start_bit` bits.
-            //
-            // We need to handle the first block separately,
-            // because there may be bits earlier in the block that should be ignored,
-            // such as the bit marked (1) in this example:
-            //
-            //       (1)
-            //       -|------
-            //   (c) 01000000|00000000|00000001
-            //          ^~~~~~~~~~~~~~~~~~^
-            //          start              end
-            if let Some(i) =
-                search_block(init_mask.blocks[start_block], start_block, start_bit, is_init)
-            {
-                // If the range is less than a block, we may find a matching bit after `end`.
-                //
-                // For example, we shouldn't successfully find bit (2), because it's after `end`:
-                //
-                //             (2)
-                //       -------|
-                //   (d) 00000001|00000000|00000001
-                //         ^~~~~^
-                //         start end
-                //
-                // An alternative would be to mask off end bits in the same way as we do for start bits,
-                // but performing this check afterwards is faster and simpler to implement.
-                if i < end {
-                    return Some(i);
-                } else {
-                    return None;
-                }
-            }
-
-            // Handle remaining blocks.
-            //
-            // We can skip over an entire block at once if it's all 0s (resp. 1s).
-            // The block marked (3) in this example is the first block that will be handled by this loop,
-            // and it will be skipped for that reason:
-            //
-            //                   (3)
-            //                   --------
-            //   (e) 01000000|00000000|00000001
-            //          ^~~~~~~~~~~~~~~~~~^
-            //          start              end
-            if start_block < end_block_inclusive {
-                // This loop is written in a specific way for performance.
-                // Notably: `..end_block_inclusive + 1` is used for an inclusive range instead of `..=end_block_inclusive`,
-                // and `.zip(start_block + 1..)` is used to track the index instead of `.enumerate().skip().take()`,
-                // because both alternatives result in significantly worse codegen.
-                // `end_block_inclusive + 1` is guaranteed not to wrap, because `end_block_inclusive <= end / BLOCK_SIZE`,
-                // and `BLOCK_SIZE` (the number of bits per block) will always be at least 8 (1 byte).
-                for (&bits, block) in init_mask.blocks[start_block + 1..end_block_inclusive + 1]
-                    .iter()
-                    .zip(start_block + 1..)
-                {
-                    if let Some(i) = search_block(bits, block, 0, is_init) {
-                        // If this is the last block, we may find a matching bit after `end`.
-                        //
-                        // For example, we shouldn't successfully find bit (4), because it's after `end`:
-                        //
-                        //                             (4)
-                        //                       -------|
-                        //   (f) 00000001|00000000|00000001
-                        //          ^~~~~~~~~~~~~~~~~~^
-                        //          start              end
-                        //
-                        // As above with example (d), we could handle the end block separately and mask off end bits,
-                        // but unconditionally searching an entire block at once and performing this check afterwards
-                        // is faster and much simpler to implement.
-                        if i < end {
-                            return Some(i);
-                        } else {
-                            return None;
-                        }
-                    }
-                }
-            }
-
-            None
-        }
-
-        #[cfg_attr(not(debug_assertions), allow(dead_code))]
-        fn find_bit_slow(
-            init_mask: &InitMask,
-            start: Size,
-            end: Size,
-            is_init: bool,
-        ) -> Option<Size> {
-            (start..end).find(|&i| init_mask.get(i) == is_init)
+        if is_init {
+            self.set.first_set_in(start.bytes_usize()..end.bytes_usize()).map(Size::from_bytes)
+        } else {
+            self.set.first_gap_in(start.bytes_usize()..end.bytes_usize()).map(Size::from_bytes)
         }
-
-        let result = find_bit_fast(self, start, end, is_init);
-
-        debug_assert_eq!(
-            result,
-            find_bit_slow(self, start, end, is_init),
-            "optimized implementation of find_bit is wrong for start={:?} end={:?} is_init={} init_mask={:#?}",
-            start,
-            end,
-            is_init,
-            self
-        );
-
-        result
     }
 }
 
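Most of the 285 deleted lines are `find_bit`'s hand-optimized fast path: `find_bit_fast` with its nested `search_block`, and the `debug_assert_eq!` cross-check against `find_bit_slow`, all collapse into two interval queries, `first_set_in` and `first_gap_in`. Note also the new manual `Ord`/`PartialOrd` impls at the top of the hunk: the derives are dropped, presumably because `IntervalSet` does not itself implement `Ord`, so masks are ordered by comparing their iterated contents and then their domain sizes. The sketch below shows how the two queries can be answered on a sorted, disjoint run list; these are illustrative free functions, not the `rustc_index` implementation:

// How `first_set_in`/`first_gap_in`-style queries can be answered on a
// sorted, disjoint list of half-open (start, end) runs of initialized
// bytes. Illustrative only, not the rustc_index sources.

/// First initialized byte in `lo..hi`, if any.
fn first_set_in(ranges: &[(usize, usize)], lo: usize, hi: usize) -> Option<usize> {
    ranges
        .iter()
        .find(|&&(_, e)| e > lo)  // first run ending after `lo`
        .map(|&(s, _)| s.max(lo)) // clamp its start into the window
        .filter(|&i| i < hi)      // discard hits at or past `hi`
}

/// First uninitialized byte in `lo..hi`, if any.
fn first_gap_in(ranges: &[(usize, usize)], lo: usize, hi: usize) -> Option<usize> {
    let mut cursor = lo;
    for &(s, e) in ranges {
        if e <= cursor {
            continue; // run lies entirely before the cursor
        }
        if s > cursor {
            break; // a gap opens at `cursor` before this run
        }
        cursor = e; // run covers `cursor`: jump past it
    }
    if cursor < hi { Some(cursor) } else { None }
}

fn main() {
    let ranges = [(0, 8), (16, 32)]; // bytes 0..8 and 16..32 initialized
    assert_eq!(first_set_in(&ranges, 8, 32), Some(16));
    assert_eq!(first_set_in(&ranges, 8, 16), None);
    assert_eq!(first_gap_in(&ranges, 0, 32), Some(8));
    assert_eq!(first_gap_in(&ranges, 16, 32), None);
}

Both queries cost a walk over the stored runs rather than over every 64-bit word of a bitset, which is what makes the deleted block-skipping machinery unnecessary.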
@@ -918,8 +667,8 @@ impl InitMask {
     /// indexes for the first contiguous span of the uninitialized access.
     #[inline]
     pub fn is_range_initialized(&self, start: Size, end: Size) -> Result<(), Range<Size>> {
-        if end > self.len {
-            return Err(self.len..end);
+        if end.bytes_usize() > self.set.domain_size() {
+            return Err(Size::from_bytes(self.set.domain_size())..end);
         }
 
         let uninit_start = self.find_bit(start, end, false);
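
With the explicit `len` field gone, the tracked length is recovered from `set.domain_size()`, and `is_range_initialized` keeps its contract of reporting the offending range on failure. A toy restatement of the new bounds check (names hypothetical; `Size` and the real `InitMask` are rustc-internal and omitted here):

// Toy restatement of the new bounds check, with hypothetical names.
struct ToyMask {
    domain_size: usize, // what `self.set.domain_size()` now reports
}

impl ToyMask {
    // Old:  if end > self.len { return Err(self.len..end); }
    // New:  the same check, with the length derived from the set's domain.
    fn check_bounds(&self, end: usize) -> Result<(), std::ops::Range<usize>> {
        if end > self.domain_size {
            return Err(self.domain_size..end);
        }
        Ok(())
    }
}

fn main() {
    let mask = ToyMask { domain_size: 16 };
    assert_eq!(mask.check_bounds(8), Ok(()));
    assert_eq!(mask.check_bounds(24), Err(16..24)); // out-of-bounds tail
}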
@@ -943,7 +692,7 @@ impl InitMask {
     /// - Chunks alternate between [`InitChunk::Init`] and [`InitChunk::Uninit`].
     #[inline]
     pub fn range_as_init_chunks(&self, start: Size, end: Size) -> InitChunkIter<'_> {
-        assert!(end <= self.len);
+        assert!(end.bytes_usize() <= self.set.domain_size());
 
         let is_init = if start < end {
             self.get(start)
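
The last hunk is the same `len` to `domain_size()` substitution in `range_as_init_chunks`. Alternating `Init`/`Uninit` chunks also fall naturally out of a run-based representation, as this sketch of the iteration idea shows (illustrative only, not the actual `InitChunkIter`):

// Sketch: alternating init/uninit chunks enumerated directly from a
// sorted run list. `(bool, start, end)` stands in for InitChunk.
fn chunks(ranges: &[(usize, usize)], lo: usize, hi: usize) -> Vec<(bool, usize, usize)> {
    let mut out = Vec::new();
    let mut cursor = lo;
    for &(s, e) in ranges {
        let (s, e) = (s.max(lo), e.min(hi));
        if s >= e {
            continue; // run falls outside the window
        }
        if s > cursor {
            out.push((false, cursor, s)); // uninit gap before this run
        }
        out.push((true, s, e)); // the initialized run itself
        cursor = e;
    }
    if cursor < hi {
        out.push((false, cursor, hi)); // trailing uninit gap
    }
    out
}

fn main() {
    let ranges = [(0, 8), (16, 32)];
    assert_eq!(
        chunks(&ranges, 0, 40),
        vec![(true, 0, 8), (false, 8, 16), (true, 16, 32), (false, 32, 40)]
    );
}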
