@@ -247,6 +247,7 @@ use super::SpecExtend;
247
247
/// [peek]: BinaryHeap::peek
248
248
/// [peek\_mut]: BinaryHeap::peek_mut
249
249
#[ stable( feature = "rust1" , since = "1.0.0" ) ]
250
+ #[ cfg_attr( not( test) , rustc_diagnostic_item = "BinaryHeap" ) ]
250
251
pub struct BinaryHeap < T > {
251
252
data : Vec < T > ,
252
253
}
@@ -275,7 +276,8 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
275
276
impl < T : Ord > Drop for PeekMut < ' _ , T > {
276
277
fn drop ( & mut self ) {
277
278
if self . sift {
278
- self . heap . sift_down ( 0 ) ;
279
+ // SAFETY: PeekMut is only instantiated for non-empty heaps.
280
+ unsafe { self . heap . sift_down ( 0 ) } ;
279
281
}
280
282
}
281
283
}
@@ -431,7 +433,8 @@ impl<T: Ord> BinaryHeap<T> {
431
433
self . data . pop ( ) . map ( |mut item| {
432
434
if !self . is_empty ( ) {
433
435
swap ( & mut item, & mut self . data [ 0 ] ) ;
434
- self . sift_down_to_bottom ( 0 ) ;
436
+ // SAFETY: !self.is_empty() means that self.len() > 0
437
+ unsafe { self . sift_down_to_bottom ( 0 ) } ;
435
438
}
436
439
item
437
440
} )
@@ -473,7 +476,9 @@ impl<T: Ord> BinaryHeap<T> {
473
476
pub fn push ( & mut self , item : T ) {
474
477
let old_len = self . len ( ) ;
475
478
self . data . push ( item) ;
476
- self . sift_up ( 0 , old_len) ;
479
+ // SAFETY: Since we pushed a new item it means that
480
+ // old_len = self.len() - 1 < self.len()
481
+ unsafe { self . sift_up ( 0 , old_len) } ;
477
482
}
478
483
479
484
/// Consumes the `BinaryHeap` and returns a vector in sorted
@@ -506,7 +511,10 @@ impl<T: Ord> BinaryHeap<T> {
506
511
let ptr = self . data . as_mut_ptr ( ) ;
507
512
ptr:: swap ( ptr, ptr. add ( end) ) ;
508
513
}
509
- self . sift_down_range ( 0 , end) ;
514
+ // SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
515
+ // 0 < 1 <= end <= self.len() - 1 < self.len()
516
+ // Which means 0 < end and end < self.len().
517
+ unsafe { self . sift_down_range ( 0 , end) } ;
510
518
}
511
519
self . into_vec ( )
512
520
}
@@ -519,78 +527,139 @@ impl<T: Ord> BinaryHeap<T> {
519
527
// the hole is filled back at the end of its scope, even on panic.
520
528
// Using a hole reduces the constant factor compared to using swaps,
521
529
// which involves twice as many moves.
522
- fn sift_up ( & mut self , start : usize , pos : usize ) -> usize {
523
- unsafe {
524
- // Take out the value at `pos` and create a hole.
525
- let mut hole = Hole :: new ( & mut self . data , pos) ;
526
-
527
- while hole. pos ( ) > start {
528
- let parent = ( hole. pos ( ) - 1 ) / 2 ;
529
- if hole. element ( ) <= hole. get ( parent) {
530
- break ;
531
- }
532
- hole. move_to ( parent) ;
530
+
531
+ /// # Safety
532
+ ///
533
+ /// The caller must guarantee that `pos < self.len()`.
534
+ unsafe fn sift_up ( & mut self , start : usize , pos : usize ) -> usize {
535
+ // Take out the value at `pos` and create a hole.
536
+ // SAFETY: The caller guarantees that pos < self.len()
537
+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
538
+
539
+ while hole. pos ( ) > start {
540
+ let parent = ( hole. pos ( ) - 1 ) / 2 ;
541
+
542
+ // SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
543
+ // and so hole.pos() - 1 can't underflow.
544
+ // This guarantees that parent < hole.pos() so
545
+ // it's a valid index and also != hole.pos().
546
+ if hole. element ( ) <= unsafe { hole. get ( parent) } {
547
+ break ;
533
548
}
534
- hole. pos ( )
549
+
550
+ // SAFETY: Same as above
551
+ unsafe { hole. move_to ( parent) } ;
535
552
}
553
+
554
+ hole. pos ( )
536
555
}
537
556
538
557
/// Take an element at `pos` and move it down the heap,
539
558
/// while its children are larger.
540
- fn sift_down_range ( & mut self , pos : usize , end : usize ) {
541
- unsafe {
542
- let mut hole = Hole :: new ( & mut self . data , pos) ;
543
- let mut child = 2 * pos + 1 ;
544
- while child < end - 1 {
545
- // compare with the greater of the two children
546
- child += ( hole. get ( child) <= hole. get ( child + 1 ) ) as usize ;
547
- // if we are already in order, stop.
548
- if hole. element ( ) >= hole. get ( child) {
549
- return ;
550
- }
551
- hole. move_to ( child) ;
552
- child = 2 * hole. pos ( ) + 1 ;
553
- }
554
- if child == end - 1 && hole. element ( ) < hole. get ( child) {
555
- hole. move_to ( child) ;
559
+ ///
560
+ /// # Safety
561
+ ///
562
+ /// The caller must guarantee that `pos < end <= self.len()`.
563
+ unsafe fn sift_down_range ( & mut self , pos : usize , end : usize ) {
564
+ // SAFETY: The caller guarantees that pos < end <= self.len().
565
+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
566
+ let mut child = 2 * hole. pos ( ) + 1 ;
567
+
568
+ // Loop invariant: child == 2 * hole.pos() + 1.
569
+ while child <= end. saturating_sub ( 2 ) {
570
+ // compare with the greater of the two children
571
+ // SAFETY: child < end - 1 < self.len() and
572
+ // child + 1 < end <= self.len(), so they're valid indexes.
573
+ // child == 2 * hole.pos() + 1 != hole.pos() and
574
+ // child + 1 == 2 * hole.pos() + 2 != hole.pos().
575
+ // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
576
+ // if T is a ZST
577
+ child += unsafe { hole. get ( child) <= hole. get ( child + 1 ) } as usize ;
578
+
579
+ // if we are already in order, stop.
580
+ // SAFETY: child is now either the old child or the old child+1
581
+ // We already proven that both are < self.len() and != hole.pos()
582
+ if hole. element ( ) >= unsafe { hole. get ( child) } {
583
+ return ;
556
584
}
585
+
586
+ // SAFETY: same as above.
587
+ unsafe { hole. move_to ( child) } ;
588
+ child = 2 * hole. pos ( ) + 1 ;
589
+ }
590
+
591
+ // SAFETY: && short circuit, which means that in the
592
+ // second condition it's already true that child == end - 1 < self.len().
593
+ if child == end - 1 && hole. element ( ) < unsafe { hole. get ( child) } {
594
+ // SAFETY: child is already proven to be a valid index and
595
+ // child == 2 * hole.pos() + 1 != hole.pos().
596
+ unsafe { hole. move_to ( child) } ;
557
597
}
558
598
}
559
599
560
- fn sift_down ( & mut self , pos : usize ) {
600
+ /// # Safety
601
+ ///
602
+ /// The caller must guarantee that `pos < self.len()`.
603
+ unsafe fn sift_down ( & mut self , pos : usize ) {
561
604
let len = self . len ( ) ;
562
- self . sift_down_range ( pos, len) ;
605
+ // SAFETY: pos < len is guaranteed by the caller and
606
+ // obviously len = self.len() <= self.len().
607
+ unsafe { self . sift_down_range ( pos, len) } ;
563
608
}
564
609
565
610
/// Take an element at `pos` and move it all the way down the heap,
566
611
/// then sift it up to its position.
567
612
///
568
613
/// Note: This is faster when the element is known to be large / should
569
614
/// be closer to the bottom.
570
- fn sift_down_to_bottom ( & mut self , mut pos : usize ) {
615
+ ///
616
+ /// # Safety
617
+ ///
618
+ /// The caller must guarantee that `pos < self.len()`.
619
+ unsafe fn sift_down_to_bottom ( & mut self , mut pos : usize ) {
571
620
let end = self . len ( ) ;
572
621
let start = pos;
573
- unsafe {
574
- let mut hole = Hole :: new ( & mut self . data , pos) ;
575
- let mut child = 2 * pos + 1 ;
576
- while child < end - 1 {
577
- child += ( hole. get ( child) <= hole. get ( child + 1 ) ) as usize ;
578
- hole. move_to ( child) ;
579
- child = 2 * hole. pos ( ) + 1 ;
580
- }
581
- if child == end - 1 {
582
- hole. move_to ( child) ;
583
- }
584
- pos = hole. pos ;
622
+
623
+ // SAFETY: The caller guarantees that pos < self.len().
624
+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
625
+ let mut child = 2 * hole. pos ( ) + 1 ;
626
+
627
+ // Loop invariant: child == 2 * hole.pos() + 1.
628
+ while child <= end. saturating_sub ( 2 ) {
629
+ // SAFETY: child < end - 1 < self.len() and
630
+ // child + 1 < end <= self.len(), so they're valid indexes.
631
+ // child == 2 * hole.pos() + 1 != hole.pos() and
632
+ // child + 1 == 2 * hole.pos() + 2 != hole.pos().
633
+ // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
634
+ // if T is a ZST
635
+ child += unsafe { hole. get ( child) <= hole. get ( child + 1 ) } as usize ;
636
+
637
+ // SAFETY: Same as above
638
+ unsafe { hole. move_to ( child) } ;
639
+ child = 2 * hole. pos ( ) + 1 ;
585
640
}
586
- self . sift_up ( start, pos) ;
641
+
642
+ if child == end - 1 {
643
+ // SAFETY: child == end - 1 < self.len(), so it's a valid index
644
+ // and child == 2 * hole.pos() + 1 != hole.pos().
645
+ unsafe { hole. move_to ( child) } ;
646
+ }
647
+ pos = hole. pos ( ) ;
648
+ drop ( hole) ;
649
+
650
+ // SAFETY: pos is the position in the hole and was already proven
651
+ // to be a valid index.
652
+ unsafe { self . sift_up ( start, pos) } ;
587
653
}
588
654
589
655
fn rebuild ( & mut self ) {
590
656
let mut n = self . len ( ) / 2 ;
591
657
while n > 0 {
592
658
n -= 1 ;
593
- self . sift_down ( n) ;
659
+ // SAFETY: n starts from self.len() / 2 and goes down to 0.
660
+ // The only case when !(n < self.len()) is if
661
+ // self.len() == 0, but it's ruled out by the loop condition.
662
+ unsafe { self . sift_down ( n) } ;
594
663
}
595
664
}
596
665
0 commit comments