Skip to content

refactor: replace Range with a bounded implementation #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 25, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 74 additions & 130 deletions src/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
use crate::{internal::small_vec::SmallVec, version_set::VersionSet};
use std::ops::RangeBounds;
use std::{
cmp::Ordering,
fmt::{Debug, Display, Formatter},
ops::Bound::{self, Excluded, Included, Unbounded},
};
Expand Down Expand Up @@ -202,15 +201,40 @@ impl<V: Ord> Range<V> {
Excluded(v) => Excluded(v.clone().into()),
Unbounded => Unbounded,
};
match (start, end) {
(Included(a), Included(b)) if b < a => Self::empty(),
(Excluded(a), Excluded(b)) if b < a => Self::empty(),
(Included(a), Excluded(b)) if b <= a => Self::empty(),
(Excluded(a), Included(b)) if b <= a => Self::empty(),
(a, b) => Self {
segments: SmallVec::one((a, b)),
},
if valid_segment(&start, &end) {
Self {
segments: SmallVec::one((start, end)),
}
} else {
Self::empty()
}
}

fn check_invariants(self) -> Self {
if cfg!(debug_assertions) {
for (i, (s, e)) in self.segments.iter().enumerate() {
if matches!(s, Unbounded) && i != 0 {
panic!()
}
if matches!(e, Unbounded) && i != (self.segments.len() - 1) {
panic!()
}
}
for p in self.segments.as_slice().windows(2) {
match (&p[0].1, &p[1].0) {
(Included(l_end), Included(r_start)) => assert!(l_end < r_start),
(Included(l_end), Excluded(r_start)) => assert!(l_end < r_start),
(Excluded(l_end), Included(r_start)) => assert!(l_end < r_start),
(Excluded(l_end), Excluded(r_start)) => assert!(l_end <= r_start),
(_, Unbounded) => panic!(),
(Unbounded, _) => panic!(),
}
}
for (s, e) in self.segments.iter() {
assert!(valid_segment(s, e));
}
}
self
}
}

Expand All @@ -223,142 +247,62 @@ fn bound_as_ref<V>(bound: &Bound<V>) -> Bound<&V> {
}
}

fn valid_segment<T: PartialOrd>(start: &Bound<T>, end: &Bound<T>) -> bool {
match (start, end) {
(Included(s), Included(e)) => s <= e,
(Included(s), Excluded(e)) => s < e,
(Excluded(s), Included(e)) => s < e,
(Excluded(s), Excluded(e)) => s < e,
(Unbounded, _) | (_, Unbounded) => true,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, there might be representations special cases where we have an invalid segment that cannot be checked by this function like (Unbounded, Excluded(0)) in the case where versions are u32. It depends if there are other ways to represent the left-most and right-most bounds with the version type.

What do you guys think of that? negligible?

}
}

impl<V: Ord + Clone> Range<V> {
/// Computes the union of this `Range` and another.
pub fn union(&self, other: &Self) -> Self {
self.complement()
.intersection(&other.complement())
.complement()
.check_invariants()
}

/// Computes the intersection of two sets of versions.
pub fn intersection(&self, other: &Self) -> Self {
Copy link
Member

@Eh2406 Eh2406 May 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spent time today trying to write a "more obviously correct" intersection, with the same perf. I did not succeed. I did find this method helpful for catching corner cases:

    fn check_invariants(&self) {
        if cfg!(debug_assertions) {
            for (i, (s, e)) in self.segments.iter().enumerate() {
                if matches!(s, Unbounded) && i != 0 {
                    panic!()
                }
                if matches!(e, Unbounded) && i != (self.segments.len() - 1) {
                    panic!()
                }
            }
            for p in self.segments.as_slice().windows(2) {
                match (&p[0].1, &p[1].0) {
                    (Included(l_end), Included(r_start)) => assert!(l_end < r_start),
                    (Included(l_end), Excluded(r_start)) => assert!(l_end < r_start),
                    (Excluded(l_end), Included(r_start)) => assert!(l_end < r_start),
                    (Excluded(l_end), Excluded(r_start)) => assert!(l_end <= r_start),
                    (_, Unbounded) => panic!(),
                    (Unbounded, _) => panic!(),
                }
            }
            for (s, e) in self.segments.iter() {
                assert!(match (s, e) {
                    (Included(s), Included(e)) => s <= e,
                    (Included(s), Excluded(e)) => s < e,
                    (Excluded(s), Included(e)) => s < e,
                    (Excluded(s), Excluded(e)) => s < e,
                    (Unbounded, _) | (_, Unbounded) => true,
                });
            }
        }
    }

Perhaps we can add a call to it to the end of all methods that construct a range?

Copy link
Member

@Eh2406 Eh2406 May 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is as close as I came:
6ac1e06?diff=split

perf is slightly worse, but I find it more readable. what do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do like that its a lot shorter and easier to follow, my code had a lot of cases. How much worse is the performance?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rerunning the benchmarks now I am not seeing significant differences between our two implementations! I guess I shouldn't try benchmarking at 2 o'clock in the morning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, Ill copy your implementation into the MR. Should I also add the invariant check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or can I simply merge your changes?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge/copy/rewrite, As you wish.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the check invariants function! Is the first for check not included in the second already btw?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cherry-picked your code (crediting you) and added the suggestion about next_if by @mpizenberg

let mut segments: SmallVec<Interval<V>> = SmallVec::empty();
let mut left_iter = self.segments.iter();
let mut right_iter = other.segments.iter();
let mut left = left_iter.next();
let mut right = right_iter.next();
while let (Some((left_lower, left_upper)), Some((right_lower, right_upper))) = (left, right)
let mut left_iter = self.segments.iter().peekable();
let mut right_iter = other.segments.iter().peekable();

while let (Some((left_start, left_end)), Some((right_start, right_end))) =
(left_iter.peek(), right_iter.peek())
{
// Check if the left range completely smaller than the right range.
if let (
Included(left_upper_version) | Excluded(left_upper_version),
Included(right_lower_version) | Excluded(right_lower_version),
) = (left_upper, right_lower)
{
match left_upper_version.cmp(right_lower_version) {
Ordering::Less => {
// Left range is disjoint from the right range.
left = left_iter.next();
continue;
}
Ordering::Equal => {
if !matches!((left_upper, right_lower), (Included(_), Included(_))) {
// Left and right are overlapping exactly, but one of the bounds is exclusive, therefor the ranges are disjoint
left = left_iter.next();
continue;
}
}
Ordering::Greater => {
// Left upper bound is greater than right lower bound, so the lower bound is the right lower bound
}
}
let start = match (left_start, right_start) {
(Included(l), Included(r)) => Included(std::cmp::max(l, r)),
(Excluded(l), Excluded(r)) => Excluded(std::cmp::max(l, r)),

(Included(i), Excluded(e)) | (Excluded(e), Included(i)) if i <= e => Excluded(e),
(Included(i), Excluded(e)) | (Excluded(e), Included(i)) if e < i => Included(i),
(s, Unbounded) | (Unbounded, s) => bound_as_ref(s),
_ => unreachable!(),
}
// Check if the right range completely smaller than the left range.
if let (
Included(left_lower_version) | Excluded(left_lower_version),
Included(right_upper_version) | Excluded(right_upper_version),
) = (left_lower, right_upper)
{
match right_upper_version.cmp(left_lower_version) {
Ordering::Less => {
// Right range is disjoint from the left range.
right = right_iter.next();
continue;
}
Ordering::Equal => {
if !matches!((right_upper, left_lower), (Included(_), Included(_))) {
// Left and right are overlapping exactly, but one of the bounds is exclusive, therefor the ranges are disjoint
right = right_iter.next();
continue;
}
}
Ordering::Greater => {
// Right upper bound is greater than left lower bound, so the lower bound is the left lower bound
}
}
.cloned();
let end = match (left_end, right_end) {
(Included(l), Included(r)) => Included(std::cmp::min(l, r)),
(Excluded(l), Excluded(r)) => Excluded(std::cmp::min(l, r)),

(Included(i), Excluded(e)) | (Excluded(e), Included(i)) if i >= e => Excluded(e),
(Included(i), Excluded(e)) | (Excluded(e), Included(i)) if e > i => Included(i),
(s, Unbounded) | (Unbounded, s) => bound_as_ref(s),
_ => unreachable!(),
}
.cloned();
left_iter.next_if(|(_, e)| e == &end);
right_iter.next_if(|(_, e)| e == &end);
if valid_segment(&start, &end) {
segments.push((start, end))
}

// At this point we know there is an overlap between the versions, find the lowest bound
let lower = match (left_lower, right_lower) {
(Unbounded, Included(_) | Excluded(_)) => right_lower.clone(),
(Included(_) | Excluded(_), Unbounded) => left_lower.clone(),
(Unbounded, Unbounded) => Unbounded,
(Included(l) | Excluded(l), Included(r) | Excluded(r)) => match l.cmp(r) {
Ordering::Less => right_lower.clone(),
Ordering::Equal => match (left_lower, right_lower) {
(Included(_), Excluded(v)) => Excluded(v.clone()),
(Excluded(_), Excluded(v)) => Excluded(v.clone()),
(Excluded(v), Included(_)) => Excluded(v.clone()),
(Included(_), Included(v)) => Included(v.clone()),
_ => unreachable!(),
},
Ordering::Greater => left_lower.clone(),
},
};

// At this point we know there is an overlap between the versions, find the lowest bound
let upper = match (left_upper, right_upper) {
(Unbounded, Included(_) | Excluded(_)) => {
right = right_iter.next();
right_upper.clone()
}
(Included(_) | Excluded(_), Unbounded) => {
left = left_iter.next();
left_upper.clone()
}
(Unbounded, Unbounded) => {
left = left_iter.next();
right = right_iter.next();
Unbounded
}
(Included(l) | Excluded(l), Included(r) | Excluded(r)) => match l.cmp(r) {
Ordering::Less => {
left = left_iter.next();
left_upper.clone()
}
Ordering::Equal => match (left_upper, right_upper) {
(Included(_), Excluded(v)) => {
right = right_iter.next();
Excluded(v.clone())
}
(Excluded(_), Excluded(v)) => {
left = left_iter.next();
right = right_iter.next();
Excluded(v.clone())
}
(Excluded(v), Included(_)) => {
left = left_iter.next();
Excluded(v.clone())
}
(Included(_), Included(v)) => {
left = left_iter.next();
right = right_iter.next();
Included(v.clone())
}
_ => unreachable!(),
},
Ordering::Greater => {
right = right_iter.next();
right_upper.clone()
}
},
};

segments.push((lower, upper));
}

Self { segments }
Self { segments }.check_invariants()
}
}

Expand Down Expand Up @@ -614,7 +558,7 @@ pub mod tests {

#[test]
fn from_range_bounds(range in any::<(Bound<u32>, Bound<u32>)>(), version in version_strat()) {
let rv: Range<u32> = Range::from_range_bounds(range);
let rv: Range<_> = Range::from_range_bounds(range);
assert_eq!(range.contains(&version), rv.contains(&version));
}

Expand Down