|
52 | 52 | //!
|
53 | 53 | //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
54 | 54 | //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
| 55 | +use std::cmp::Ordering; |
55 | 56 | use std::fmt::Debug;
|
56 | 57 | use std::marker::PhantomData;
|
57 | 58 |
|
@@ -231,48 +232,50 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
231 | 232 | yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
232 | 233 | }
|
233 | 234 |
|
234 |
| - if k < 2 { |
235 |
| - Err(Failed::fit(&format!( |
| 235 | + match k.cmp(&2) { |
| 236 | + Ordering::Less => Err(Failed::fit(&format!( |
236 | 237 | "incorrect number of classes: {}. Should be >= 2.",
|
237 | 238 | k
|
238 |
| - ))) |
239 |
| - } else if k == 2 { |
240 |
| - let x0 = M::zeros(1, num_attributes + 1); |
241 |
| - |
242 |
| - let objective = BinaryObjectiveFunction { |
243 |
| - x, |
244 |
| - y: yi, |
245 |
| - phantom: PhantomData, |
246 |
| - }; |
247 |
| - |
248 |
| - let result = LogisticRegression::minimize(x0, objective); |
249 |
| - |
250 |
| - Ok(LogisticRegression { |
251 |
| - weights: result.x, |
252 |
| - classes, |
253 |
| - num_attributes, |
254 |
| - num_classes: k, |
255 |
| - }) |
256 |
| - } else { |
257 |
| - let x0 = M::zeros(1, (num_attributes + 1) * k); |
258 |
| - |
259 |
| - let objective = MultiClassObjectiveFunction { |
260 |
| - x, |
261 |
| - y: yi, |
262 |
| - k, |
263 |
| - phantom: PhantomData, |
264 |
| - }; |
265 |
| - |
266 |
| - let result = LogisticRegression::minimize(x0, objective); |
267 |
| - |
268 |
| - let weights = result.x.reshape(k, num_attributes + 1); |
269 |
| - |
270 |
| - Ok(LogisticRegression { |
271 |
| - weights, |
272 |
| - classes, |
273 |
| - num_attributes, |
274 |
| - num_classes: k, |
275 |
| - }) |
| 239 | + ))), |
| 240 | + Ordering::Greater => { |
| 241 | + let x0 = M::zeros(1, (num_attributes + 1) * k); |
| 242 | + |
| 243 | + let objective = MultiClassObjectiveFunction { |
| 244 | + x, |
| 245 | + y: yi, |
| 246 | + k, |
| 247 | + phantom: PhantomData, |
| 248 | + }; |
| 249 | + |
| 250 | + let result = LogisticRegression::minimize(x0, objective); |
| 251 | + |
| 252 | + let weights = result.x.reshape(k, num_attributes + 1); |
| 253 | + |
| 254 | + Ok(LogisticRegression { |
| 255 | + weights, |
| 256 | + classes, |
| 257 | + num_attributes, |
| 258 | + num_classes: k, |
| 259 | + }) |
| 260 | + } |
| 261 | + Ordering::Equal => { |
| 262 | + let x0 = M::zeros(1, num_attributes + 1); |
| 263 | + |
| 264 | + let objective = BinaryObjectiveFunction { |
| 265 | + x, |
| 266 | + y: yi, |
| 267 | + phantom: PhantomData, |
| 268 | + }; |
| 269 | + |
| 270 | + let result = LogisticRegression::minimize(x0, objective); |
| 271 | + |
| 272 | + Ok(LogisticRegression { |
| 273 | + weights: result.x, |
| 274 | + classes, |
| 275 | + num_attributes, |
| 276 | + num_classes: k, |
| 277 | + }) |
| 278 | + } |
276 | 279 | }
|
277 | 280 | }
|
278 | 281 |
|
|
0 commit comments