Skip to content

Commit cc26555

Browse files
Volodymyr OrlovVolodymyr Orlov
authored andcommitted
fix: fixes suggested by Clippy
1 parent c42fccd commit cc26555

File tree

1 file changed

+46
-44
lines changed

1 file changed

+46
-44
lines changed

src/linear/logistic_regression.rs

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
//!
5353
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
5454
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
55+
use std::cmp::Ordering;
5556
use std::fmt::Debug;
5657
use std::marker::PhantomData;
5758

@@ -232,51 +233,53 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
232233
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
233234
}
234235

235-
if k < 2 {
236-
Err(Failed::fit(&format!(
236+
match k.cmp(&2) {
237+
Ordering::Less => Err(Failed::fit(&format!(
237238
"incorrect number of classes: {}. Should be >= 2.",
238239
k
239-
)))
240-
} else if k == 2 {
241-
let x0 = M::zeros(1, num_attributes + 1);
242-
243-
let objective = BinaryObjectiveFunction {
244-
x: x,
245-
y: yi,
246-
phantom: PhantomData,
247-
};
248-
249-
let result = LogisticRegression::minimize(x0, objective);
250-
let weights = result.x;
251-
252-
Ok(LogisticRegression {
253-
coefficients: weights.slice(0..1, 0..num_attributes),
254-
intercept: weights.slice(0..1, num_attributes..num_attributes + 1),
255-
classes: classes,
256-
num_attributes: num_attributes,
257-
num_classes: k,
258-
})
259-
} else {
260-
let x0 = M::zeros(1, (num_attributes + 1) * k);
261-
262-
let objective = MultiClassObjectiveFunction {
263-
x: x,
264-
y: yi,
265-
k: k,
266-
phantom: PhantomData,
267-
};
268-
269-
let result = LogisticRegression::minimize(x0, objective);
270-
271-
let weights = result.x.reshape(k, num_attributes + 1);
272-
273-
Ok(LogisticRegression {
274-
coefficients: weights.slice(0..k, 0..num_attributes),
275-
intercept: weights.slice(0..k, num_attributes..num_attributes + 1),
276-
classes: classes,
277-
num_attributes: num_attributes,
278-
num_classes: k,
279-
})
240+
))),
241+
Ordering::Equal => {
242+
let x0 = M::zeros(1, num_attributes + 1);
243+
244+
let objective = BinaryObjectiveFunction {
245+
x: x,
246+
y: yi,
247+
phantom: PhantomData,
248+
};
249+
250+
let result = LogisticRegression::minimize(x0, objective);
251+
let weights = result.x;
252+
253+
Ok(LogisticRegression {
254+
coefficients: weights.slice(0..1, 0..num_attributes),
255+
intercept: weights.slice(0..1, num_attributes..num_attributes + 1),
256+
classes: classes,
257+
num_attributes: num_attributes,
258+
num_classes: k,
259+
})
260+
}
261+
Ordering::Greater => {
262+
let x0 = M::zeros(1, (num_attributes + 1) * k);
263+
264+
let objective = MultiClassObjectiveFunction {
265+
x: x,
266+
y: yi,
267+
k: k,
268+
phantom: PhantomData,
269+
};
270+
271+
let result = LogisticRegression::minimize(x0, objective);
272+
273+
let weights = result.x.reshape(k, num_attributes + 1);
274+
275+
Ok(LogisticRegression {
276+
coefficients: weights.slice(0..k, 0..num_attributes),
277+
intercept: weights.slice(0..k, num_attributes..num_attributes + 1),
278+
classes: classes,
279+
num_attributes: num_attributes,
280+
num_classes: k,
281+
})
282+
}
280283
}
281284
}
282285

@@ -286,7 +289,6 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
286289
let n = x.shape().0;
287290
let mut result = M::zeros(1, n);
288291
if self.num_classes == 2 {
289-
let (nrows, _) = x.shape();
290292
let y_hat: Vec<T> = x.matmul(&self.coefficients.transpose()).get_col_as_vec(0);
291293
let intercept = self.intercept.get(0, 0);
292294
for i in 0..n {

0 commit comments

Comments
 (0)