Skip to content

Commit e6b14ec

Browse files
committed
kk: replace panics with errors
1 parent 0b8bf5c commit e6b14ec

File tree

1 file changed

+25
-27
lines changed

1 file changed

+25
-27
lines changed

src/algorithms/kk.rs

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::Error;
12
use std::collections::BinaryHeap;
23
use std::ops::Sub;
34
use std::ops::SubAssign;
@@ -9,18 +10,14 @@ use num::Zero;
910
/// # Differences with the k-partitioning implementation
1011
///
1112
/// This function has better performance than [kk] called with `num_parts == 2`.
12-
fn kk_bipart<T>(partition: &mut [usize], weights: impl IntoIterator<Item = T>)
13+
fn kk_bipart<T>(partition: &mut [usize], weights: impl Iterator<Item = T>)
1314
where
1415
T: Ord + Sub<Output = T>,
1516
{
1617
let mut weights: BinaryHeap<(T, usize)> = weights
1718
.into_iter()
1819
.zip(0..) // Keep track of the weights' indicies
1920
.collect();
20-
assert_eq!(partition.len(), weights.len());
21-
if weights.is_empty() {
22-
return;
23-
}
2421

2522
// Core algorithm: find the imbalance of the partition.
2623
// "opposites" is built in this loop to backtrack the solution. It tracks weights that must end
@@ -54,30 +51,15 @@ where
5451
fn kk<T, I>(partition: &mut [usize], weights: I, num_parts: usize)
5552
where
5653
T: Zero + Ord + Sub<Output = T> + SubAssign + Copy,
57-
I: IntoIterator<Item = T>,
58-
<I as IntoIterator>::IntoIter: ExactSizeIterator,
54+
I: Iterator<Item = T> + ExactSizeIterator,
5955
{
60-
let weights = weights.into_iter();
61-
let num_weights = weights.len();
62-
assert_eq!(partition.len(), num_weights);
63-
if num_weights == 0 {
64-
return;
65-
}
66-
if num_parts < 2 {
67-
return;
68-
}
69-
if num_parts == 2 {
70-
// The bi-partitioning is a special case that can be handled faster than
71-
// the general case.
72-
return kk_bipart(partition, weights);
73-
}
74-
7556
// Initialize "m", a "k*num_weights" matrix whose first column is "weights".
57+
let weight_count = weights.len();
7658
let mut m: BinaryHeap<Vec<(T, usize)>> = weights
7759
.zip(0..)
7860
.map(|(w, id)| {
7961
let mut v: Vec<(T, usize)> = (0..num_parts)
80-
.map(|p| (T::zero(), num_weights * p + id))
62+
.map(|p| (T::zero(), weight_count * p + id))
8163
.collect();
8264
v[0].0 = w;
8365
v
@@ -88,7 +70,7 @@ where
8870
// largest weights in two different parts, the largest weight of each row is put into the same
8971
// part as the smallest one, and so on.
9072

91-
let mut opposites = Vec::with_capacity(num_weights);
73+
let mut opposites = Vec::with_capacity(weight_count);
9274
while 2 <= m.len() {
9375
let a = m.pop().unwrap();
9476
let b = m.pop().unwrap();
@@ -119,7 +101,7 @@ where
119101
// Backtracking. Same as the bi-partitioning case.
120102

121103
// parts = [ [m0i] for m0i in m[0] ]
122-
let mut parts: Vec<usize> = vec![0; num_parts * num_weights];
104+
let mut parts: Vec<usize> = vec![0; num_parts * weight_count];
123105
let imbalance = m.pop().unwrap(); // first and last element of "m".
124106
for (i, w) in imbalance.into_iter().enumerate() {
125107
// Put each remaining element in a different part.
@@ -160,14 +142,30 @@ where
160142
W::Item: Zero + Ord + Sub<Output = W::Item> + SubAssign + Copy,
161143
{
162144
type Metadata = ();
163-
type Error = std::convert::Infallible;
145+
type Error = Error;
164146

165147
fn partition(
166148
&mut self,
167149
part_ids: &mut [usize],
168150
weights: W,
169151
) -> Result<Self::Metadata, Self::Error> {
170-
kk(part_ids, weights, self.part_count);
152+
if self.part_count < 2 || part_ids.len() < 2 {
153+
return Ok(());
154+
}
155+
let weights = weights.into_iter();
156+
if weights.len() != part_ids.len() {
157+
return Err(Error::InputLenMismatch {
158+
expected: part_ids.len(),
159+
actual: weights.len(),
160+
});
161+
}
162+
if self.part_count == 2 {
163+
// The bi-partitioning is a special case that can be handled faster
164+
// than the general case.
165+
kk_bipart(part_ids, weights);
166+
} else {
167+
kk(part_ids, weights, self.part_count);
168+
}
171169
Ok(())
172170
}
173171
}

0 commit comments

Comments
 (0)