1
+ use super :: Error ;
1
2
use std:: collections:: BinaryHeap ;
2
3
use std:: ops:: Sub ;
3
4
use std:: ops:: SubAssign ;
@@ -9,18 +10,14 @@ use num::Zero;
9
10
/// # Differences with the k-partitioning implementation
10
11
///
11
12
/// This function has better performance than [kk] called with `num_parts == 2`.
12
- fn kk_bipart < T > ( partition : & mut [ usize ] , weights : impl IntoIterator < Item = T > )
13
+ fn kk_bipart < T > ( partition : & mut [ usize ] , weights : impl Iterator < Item = T > )
13
14
where
14
15
T : Ord + Sub < Output = T > ,
15
16
{
16
17
let mut weights: BinaryHeap < ( T , usize ) > = weights
17
18
. into_iter ( )
18
19
. zip ( 0 ..) // Keep track of the weights' indicies
19
20
. collect ( ) ;
20
- assert_eq ! ( partition. len( ) , weights. len( ) ) ;
21
- if weights. is_empty ( ) {
22
- return ;
23
- }
24
21
25
22
// Core algorithm: find the imbalance of the partition.
26
23
// "opposites" is built in this loop to backtrack the solution. It tracks weights that must end
@@ -54,30 +51,15 @@ where
54
51
fn kk < T , I > ( partition : & mut [ usize ] , weights : I , num_parts : usize )
55
52
where
56
53
T : Zero + Ord + Sub < Output = T > + SubAssign + Copy ,
57
- I : IntoIterator < Item = T > ,
58
- <I as IntoIterator >:: IntoIter : ExactSizeIterator ,
54
+ I : Iterator < Item = T > + ExactSizeIterator ,
59
55
{
60
- let weights = weights. into_iter ( ) ;
61
- let num_weights = weights. len ( ) ;
62
- assert_eq ! ( partition. len( ) , num_weights) ;
63
- if num_weights == 0 {
64
- return ;
65
- }
66
- if num_parts < 2 {
67
- return ;
68
- }
69
- if num_parts == 2 {
70
- // The bi-partitioning is a special case that can be handled faster than
71
- // the general case.
72
- return kk_bipart ( partition, weights) ;
73
- }
74
-
75
56
// Initialize "m", a "k*num_weights" matrix whose first column is "weights".
57
+ let weight_count = weights. len ( ) ;
76
58
let mut m: BinaryHeap < Vec < ( T , usize ) > > = weights
77
59
. zip ( 0 ..)
78
60
. map ( |( w, id) | {
79
61
let mut v: Vec < ( T , usize ) > = ( 0 ..num_parts)
80
- . map ( |p| ( T :: zero ( ) , num_weights * p + id) )
62
+ . map ( |p| ( T :: zero ( ) , weight_count * p + id) )
81
63
. collect ( ) ;
82
64
v[ 0 ] . 0 = w;
83
65
v
88
70
// largest weights in two different parts, the largest weight of each row is put into the same
89
71
// part as the smallest one, and so on.
90
72
91
- let mut opposites = Vec :: with_capacity ( num_weights ) ;
73
+ let mut opposites = Vec :: with_capacity ( weight_count ) ;
92
74
while 2 <= m. len ( ) {
93
75
let a = m. pop ( ) . unwrap ( ) ;
94
76
let b = m. pop ( ) . unwrap ( ) ;
@@ -119,7 +101,7 @@ where
119
101
// Backtracking. Same as the bi-partitioning case.
120
102
121
103
// parts = [ [m0i] for m0i in m[0] ]
122
- let mut parts: Vec < usize > = vec ! [ 0 ; num_parts * num_weights ] ;
104
+ let mut parts: Vec < usize > = vec ! [ 0 ; num_parts * weight_count ] ;
123
105
let imbalance = m. pop ( ) . unwrap ( ) ; // first and last element of "m".
124
106
for ( i, w) in imbalance. into_iter ( ) . enumerate ( ) {
125
107
// Put each remaining element in a different part.
@@ -160,14 +142,30 @@ where
160
142
W :: Item : Zero + Ord + Sub < Output = W :: Item > + SubAssign + Copy ,
161
143
{
162
144
type Metadata = ( ) ;
163
- type Error = std :: convert :: Infallible ;
145
+ type Error = Error ;
164
146
165
147
fn partition (
166
148
& mut self ,
167
149
part_ids : & mut [ usize ] ,
168
150
weights : W ,
169
151
) -> Result < Self :: Metadata , Self :: Error > {
170
- kk ( part_ids, weights, self . part_count ) ;
152
+ if self . part_count < 2 || part_ids. len ( ) < 2 {
153
+ return Ok ( ( ) ) ;
154
+ }
155
+ let weights = weights. into_iter ( ) ;
156
+ if weights. len ( ) != part_ids. len ( ) {
157
+ return Err ( Error :: InputLenMismatch {
158
+ expected : part_ids. len ( ) ,
159
+ actual : weights. len ( ) ,
160
+ } ) ;
161
+ }
162
+ if self . part_count == 2 {
163
+ // The bi-partitioning is a special case that can be handled faster
164
+ // than the general case.
165
+ kk_bipart ( part_ids, weights) ;
166
+ } else {
167
+ kk ( part_ids, weights, self . part_count ) ;
168
+ }
171
169
Ok ( ( ) )
172
170
}
173
171
}
0 commit comments