Skip to content

Commit c42fccd

Browse files
Volodymyr OrlovVolodymyr Orlov
authored andcommitted
fix: ridge regression, code refactoring
1 parent 7a4fe11 commit c42fccd

File tree

3 files changed

+63
-24
lines changed

3 files changed

+63
-24
lines changed

src/linear/linear_regression.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
154154
}
155155

156156
/// Get estimates regression coefficients
157-
pub fn coefficients(&self) -> M {
158-
self.coefficients.clone()
157+
pub fn coefficients(&self) -> &M {
158+
&self.coefficients
159159
}
160160

161161
/// Get estimate of intercept

src/linear/logistic_regression.rs

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ use crate::optimization::FunctionOrder;
6868
/// Logistic Regression
6969
#[derive(Serialize, Deserialize, Debug)]
7070
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
71-
weights: M,
71+
coefficients: M,
72+
intercept: M,
7273
classes: Vec<T>,
7374
num_attributes: usize,
7475
num_classes: usize,
@@ -109,7 +110,7 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
109110
}
110111
}
111112

112-
return self.weights == other.weights;
113+
return self.coefficients == other.coefficients && self.intercept == other.intercept;
113114
}
114115
}
115116
}
@@ -246,9 +247,11 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
246247
};
247248

248249
let result = LogisticRegression::minimize(x0, objective);
250+
let weights = result.x;
249251

250252
Ok(LogisticRegression {
251-
weights: result.x,
253+
coefficients: weights.slice(0..1, 0..num_attributes),
254+
intercept: weights.slice(0..1, num_attributes..num_attributes + 1),
252255
classes: classes,
253256
num_attributes: num_attributes,
254257
num_classes: k,
@@ -268,7 +271,8 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
268271
let weights = result.x.reshape(k, num_attributes + 1);
269272

270273
Ok(LogisticRegression {
271-
weights: weights,
274+
coefficients: weights.slice(0..k, 0..num_attributes),
275+
intercept: weights.slice(0..k, num_attributes..num_attributes + 1),
272276
classes: classes,
273277
num_attributes: num_attributes,
274278
num_classes: k,
@@ -283,21 +287,26 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
283287
let mut result = M::zeros(1, n);
284288
if self.num_classes == 2 {
285289
let (nrows, _) = x.shape();
286-
let x_and_bias = x.h_stack(&M::ones(nrows, 1));
287-
let y_hat: Vec<T> = x_and_bias
288-
.matmul(&self.weights.transpose())
289-
.get_col_as_vec(0);
290+
let y_hat: Vec<T> = x.matmul(&self.coefficients.transpose()).get_col_as_vec(0);
291+
let intercept = self.intercept.get(0, 0);
290292
for i in 0..n {
291293
result.set(
292294
0,
293295
i,
294-
self.classes[if y_hat[i].sigmoid() > T::half() { 1 } else { 0 }],
296+
self.classes[if (y_hat[i] + intercept).sigmoid() > T::half() {
297+
1
298+
} else {
299+
0
300+
}],
295301
);
296302
}
297303
} else {
298-
let (nrows, _) = x.shape();
299-
let x_and_bias = x.h_stack(&M::ones(nrows, 1));
300-
let y_hat = x_and_bias.matmul(&self.weights.transpose());
304+
let mut y_hat = x.matmul(&self.coefficients.transpose());
305+
for r in 0..n {
306+
for c in 0..self.num_classes {
307+
y_hat.set(r, c, y_hat.get(r, c) + self.intercept.get(c, 0));
308+
}
309+
}
301310
let class_idxs = y_hat.argmax();
302311
for i in 0..n {
303312
result.set(0, i, self.classes[class_idxs[i]]);
@@ -307,17 +316,13 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
307316
}
308317

309318
/// Get estimates regression coefficients
310-
pub fn coefficients(&self) -> M {
311-
self.weights
312-
.slice(0..self.num_classes, 0..self.num_attributes)
319+
pub fn coefficients(&self) -> &M {
320+
&self.coefficients
313321
}
314322

315323
/// Get estimate of intercept
316-
pub fn intercept(&self) -> M {
317-
self.weights.slice(
318-
0..self.num_classes,
319-
self.num_attributes..self.num_attributes + 1,
320-
)
324+
pub fn intercept(&self) -> &M {
325+
&self.intercept
321326
}
322327

323328
fn minimize(x0: M, objective: impl ObjectiveFunction<T, M>) -> OptimizerResult<T, M> {
@@ -336,7 +341,9 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
336341
#[cfg(test)]
337342
mod tests {
338343
use super::*;
344+
use crate::dataset::generator::make_blobs;
339345
use crate::linalg::naive::dense_matrix::*;
346+
use crate::metrics::accuracy;
340347

341348
#[test]
342349
fn multiclass_objective_f() {
@@ -466,6 +473,34 @@ mod tests {
466473
);
467474
}
468475

476+
#[test]
477+
fn lr_fit_predict_multiclass() {
478+
let blobs = make_blobs(15, 4, 3);
479+
480+
let x = DenseMatrix::from_vec(15, 4, &blobs.data);
481+
let y = blobs.target;
482+
483+
let lr = LogisticRegression::fit(&x, &y).unwrap();
484+
485+
let y_hat = lr.predict(&x).unwrap();
486+
487+
assert!(accuracy(&y_hat, &y) > 0.9);
488+
}
489+
490+
#[test]
491+
fn lr_fit_predict_binary() {
492+
let blobs = make_blobs(20, 4, 2);
493+
494+
let x = DenseMatrix::from_vec(20, 4, &blobs.data);
495+
let y = blobs.target;
496+
497+
let lr = LogisticRegression::fit(&x, &y).unwrap();
498+
499+
let y_hat = lr.predict(&x).unwrap();
500+
501+
assert!(accuracy(&y_hat, &y) > 0.9);
502+
}
503+
469504
#[test]
470505
fn serde() {
471506
let x = DenseMatrix::from_2d_array(&[

src/linear/ridge_regression.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
134134
)));
135135
}
136136

137+
if y.len() != n {
138+
return Err(Failed::fit(&format!("Number of rows in X should = len(y)")));
139+
}
140+
137141
let y_column = M::from_row_vector(y.clone()).transpose();
138142

139143
let (w, b) = if parameters.normalize {
@@ -216,8 +220,8 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
216220
}
217221

218222
/// Get estimates regression coefficients
219-
pub fn coefficients(&self) -> M {
220-
self.coefficients.clone()
223+
pub fn coefficients(&self) -> &M {
224+
&self.coefficients
221225
}
222226

223227
/// Get estimate of intercept

0 commit comments

Comments
 (0)