
Commit 68e7162

Merge pull request #72 from smartcorelib/lr_reg
feat: adds l2 regularization penalty to the Logistic Regression
2 parents 87d4e9a + 40a92ee

File tree

1 file changed: +118 −16 lines changed

src/linear/logistic_regression.rs

Lines changed: 118 additions & 16 deletions
@@ -54,7 +54,6 @@
 //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
 use std::cmp::Ordering;
 use std::fmt::Debug;
-use std::marker::PhantomData;
 
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
@@ -79,9 +78,11 @@ pub enum LogisticRegressionSolverName {
 /// Logistic Regression parameters
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
-pub struct LogisticRegressionParameters {
+pub struct LogisticRegressionParameters<T: RealNumber> {
     /// Solver to use for estimation of regression coefficients.
     pub solver: LogisticRegressionSolverName,
+    /// Regularization parameter.
+    pub alpha: T,
 }
 
 /// Logistic Regression
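
In other words, the new `alpha` field scales a standard L2 (ridge) penalty on the weights. Reconstructing from the hunks below, the penalized objective is

$$ f_{\text{reg}}(\mathbf{w}) = f(\mathbf{w}) + \frac{\alpha}{2}\,\lVert \mathbf{w} \rVert_2^2, $$

and $\alpha = 0$ (the new default) reproduces the previous unregularized behaviour.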
@@ -113,21 +114,27 @@ trait ObjectiveFunction<T: RealNumber, M: Matrix<T>> {
 struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
     x: &'a M,
     y: Vec<usize>,
-    phantom: PhantomData<&'a T>,
+    alpha: T,
 }
 
-impl LogisticRegressionParameters {
+impl<T: RealNumber> LogisticRegressionParameters<T> {
     /// Solver to use for estimation of regression coefficients.
     pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self {
         self.solver = solver;
         self
     }
+    /// Regularization parameter.
+    pub fn with_alpha(mut self, alpha: T) -> Self {
+        self.alpha = alpha;
+        self
+    }
 }
 
-impl Default for LogisticRegressionParameters {
+impl<T: RealNumber> Default for LogisticRegressionParameters<T> {
     fn default() -> Self {
         LogisticRegressionParameters {
             solver: LogisticRegressionSolverName::LBFGS,
+            alpha: T::zero(),
         }
     }
 }
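
For orientation, a minimal usage sketch of the new builder follows. The import path is an assumption based on the file location `src/linear/logistic_regression.rs` and is not part of this diff:

```rust
use smartcore::linear::logistic_regression::{
    LogisticRegressionParameters, LogisticRegressionSolverName,
};

fn main() {
    // `alpha` defaults to T::zero(), i.e. no regularization;
    // `with_alpha` chains exactly like `with_solver`.
    let params = LogisticRegressionParameters::default()
        .with_solver(LogisticRegressionSolverName::LBFGS)
        .with_alpha(0.5f64);
    println!("{:?}", params);
}
```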
@@ -156,13 +163,22 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
 {
     fn f(&self, w_bias: &M) -> T {
         let mut f = T::zero();
-        let (n, _) = self.x.shape();
+        let (n, p) = self.x.shape();
 
         for i in 0..n {
             let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
             f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx;
         }
 
+        if self.alpha > T::zero() {
+            let mut w_squared = T::zero();
+            for i in 0..p {
+                let w = w_bias.get(0, i);
+                w_squared += w * w;
+            }
+            f += T::half() * self.alpha * w_squared;
+        }
+
         f
     }
 
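Reading `ln_1pe(x)` as $\ln(1 + e^x)$ and `partial_dot` as the dot product including the bias term, the hunk above amounts to the negative log-likelihood plus the ridge term:

$$ f(\mathbf{w}, b) = \sum_{i=1}^{n}\Big[\ln\!\big(1 + e^{\mathbf{w}^\top\mathbf{x}_i + b}\big) - y_i\,(\mathbf{w}^\top\mathbf{x}_i + b)\Big] + \frac{\alpha}{2}\sum_{j=0}^{p-1} w_j^2. $$

Note that the penalty loop runs over `0..p` only, so the bias stored at column `p` of `w_bias` is not penalized, which is the usual convention.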
@@ -180,14 +196,21 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
             }
             g.set(0, p, g.get(0, p) - dyi);
         }
+
+        if self.alpha > T::zero() {
+            for i in 0..p {
+                let w = w_bias.get(0, i);
+                g.set(0, i, g.get(0, i) + self.alpha * w);
+            }
+        }
     }
 }
 
 struct MultiClassObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
     x: &'a M,
     y: Vec<usize>,
     k: usize,
-    phantom: PhantomData<&'a T>,
+    alpha: T,
 }
 
 impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
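
The matching gradient change adds $\alpha w_j$ to each weight derivative and leaves the bias derivative (column `p`) untouched:

$$ \frac{\partial f_{\text{reg}}}{\partial w_j} = \frac{\partial f}{\partial w_j} + \alpha\,w_j, \qquad j = 0,\dots,p-1. $$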
@@ -209,6 +232,17 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
             f -= prob.get(0, self.y[i]).ln();
         }
 
+        if self.alpha > T::zero() {
+            let mut w_squared = T::zero();
+            for i in 0..self.k {
+                for j in 0..p {
+                    let wi = w_bias.get(0, i * (p + 1) + j);
+                    w_squared += wi * wi;
+                }
+            }
+            f += T::half() * self.alpha * w_squared;
+        }
+
         f
     }
 
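In the multinomial case the coefficients are stored flat: class $i$ occupies columns $i(p+1)$ through $i(p+1)+p$, with the bias last. The added term therefore sums over the $k \cdot p$ weight entries and skips every bias column:

$$ f \mathrel{+}= \frac{\alpha}{2}\sum_{i=0}^{k-1}\sum_{j=0}^{p-1} w_{i(p+1)+j}^{\,2}. $$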
@@ -239,16 +273,27 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
                 g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi);
             }
         }
+
+        if self.alpha > T::zero() {
+            for i in 0..self.k {
+                for j in 0..p {
+                    let pos = i * (p + 1);
+                    let wi = w.get(0, pos + j);
+                    g.set(0, pos + j, g.get(0, pos + j) + self.alpha * wi);
+                }
+            }
+        }
     }
 }
 
-impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters>
+impl<T: RealNumber, M: Matrix<T>>
+    SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters<T>>
     for LogisticRegression<T, M>
 {
     fn fit(
         x: &M,
         y: &M::RowVector,
-        parameters: LogisticRegressionParameters,
+        parameters: LogisticRegressionParameters<T>,
     ) -> Result<Self, Failed> {
         LogisticRegression::fit(x, y, parameters)
     }
@@ -268,7 +313,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
     pub fn fit(
         x: &M,
        y: &M::RowVector,
-        _parameters: LogisticRegressionParameters,
+        parameters: LogisticRegressionParameters<T>,
     ) -> Result<LogisticRegression<T, M>, Failed> {
         let y_m = M::from_row_vector(y.clone());
         let (x_nrows, num_attributes) = x.shape();
@@ -302,7 +347,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
         let objective = BinaryObjectiveFunction {
             x,
             y: yi,
-            phantom: PhantomData,
+            alpha: parameters.alpha,
         };
 
         let result = LogisticRegression::minimize(x0, objective);
@@ -324,7 +369,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
             x,
             y: yi,
             k,
-            phantom: PhantomData,
+            alpha: parameters.alpha,
         };
 
         let result = LogisticRegression::minimize(x0, objective);
@@ -431,9 +476,9 @@ mod tests {
 
         let objective = MultiClassObjectiveFunction {
             x: &x,
-            y,
+            y: y.clone(),
             k: 3,
-            phantom: PhantomData,
+            alpha: 0.0,
         };
 
         let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
@@ -454,6 +499,24 @@ mod tests {
         ]));
 
         assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
+
+        let objective_reg = MultiClassObjectiveFunction {
+            x: &x,
+            y: y.clone(),
+            k: 3,
+            alpha: 1.0,
+        };
+
+        let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[
+            1., 2., 3., 4., 5., 6., 7., 8., 9.,
+        ]));
+        assert!((f - 487.5052).abs() < 1e-4);
+
+        objective_reg.df(
+            &mut g,
+            &DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]),
+        );
+        assert!((g.get(0, 0).abs() - 32.0).abs() < 1e-4);
     }
 
     #[test]
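
The new expected value can be checked by hand. With `w_bias = [1, …, 9]`, `k = 3`, and `p = 2`, the bias columns are 2, 5, and 8, so the penalized weights are 1, 2, 4, 5, 7, 8. Their squares sum to $1 + 4 + 16 + 25 + 49 + 64 = 159$, and $\tfrac{1}{2} \cdot 1.0 \cdot 159 = 79.5$, which is exactly the gap between the regularized and unregularized losses: $408.0052\ldots + 79.5 = 487.5052\ldots$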
@@ -480,8 +543,8 @@ mod tests {
 
         let objective = BinaryObjectiveFunction {
             x: &x,
-            y,
-            phantom: PhantomData,
+            y: y.clone(),
+            alpha: 0.0,
         };
 
         let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
@@ -496,6 +559,20 @@ mod tests {
         let f = objective.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
 
         assert!((f - 59.76994756647412).abs() < std::f64::EPSILON);
+
+        let objective_reg = BinaryObjectiveFunction {
+            x: &x,
+            y: y.clone(),
+            alpha: 1.0,
+        };
+
+        let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
+        assert!((f - 62.2699).abs() < 1e-4);
+
+        objective_reg.df(&mut g, &DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
+        assert!((g.get(0, 0) - 27.0511).abs() < 1e-4);
+        assert!((g.get(0, 1) - 12.239).abs() < 1e-4);
+        assert!((g.get(0, 2) - 3.8693).abs() < 1e-4);
     }
 
     #[test]
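
Again the numbers check out: with `w_bias = [1, 2, 3]` the bias sits at column 2, so the penalty covers only $1^2 + 2^2 = 5$, and $\tfrac{1}{2} \cdot 1.0 \cdot 5 = 2.5$ is the gap between $59.7699\ldots$ and $62.2699\ldots$ The gradient assertions show the same pattern: relative to the unregularized gradient, the first two entries gain $\alpha w_j$ (that is, $+1$ and $+2$), while the bias entry at `g.get(0, 2)` receives no penalty contribution.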
@@ -547,6 +624,15 @@ mod tests {
         let y_hat = lr.predict(&x).unwrap();
 
         assert!(accuracy(&y_hat, &y) > 0.9);
+
+        let lr_reg = LogisticRegression::fit(
+            &x,
+            &y,
+            LogisticRegressionParameters::default().with_alpha(10.0),
+        )
+        .unwrap();
+
+        assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
     }
 
     #[test]
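
This test (and the two below) asserts the expected qualitative effect: ridge shrinkage pulls the fitted coefficients toward zero, so their absolute sum drops as `alpha` grows. A self-contained sketch of the same check outside the test module follows; the import paths and trait imports are assumptions based on smartcore's layout, and the toy data is invented for illustration:

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linalg::BaseMatrix;
use smartcore::linear::logistic_regression::{
    LogisticRegression, LogisticRegressionParameters,
};

fn main() {
    // Two well-separated groups; labels given as an f64 row vector, as in the tests.
    let x = DenseMatrix::from_2d_array(&[
        &[0.1, 0.2],
        &[0.3, 0.1],
        &[0.2, 0.3],
        &[5.0, 4.9],
        &[4.8, 5.1],
        &[5.2, 5.0],
    ]);
    let y = vec![0., 0., 0., 1., 1., 1.];

    let unreg = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
    let reg = LogisticRegression::fit(
        &x,
        &y,
        LogisticRegressionParameters::default().with_alpha(10.0),
    )
    .unwrap();

    // The L2 penalty shrinks the weights toward zero.
    assert!(reg.coefficients().abs().sum() < unreg.coefficients().abs().sum());
}
```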
@@ -561,6 +647,15 @@ mod tests {
         let y_hat = lr.predict(&x).unwrap();
 
         assert!(accuracy(&y_hat, &y) > 0.9);
+
+        let lr_reg = LogisticRegression::fit(
+            &x,
+            &y,
+            LogisticRegressionParameters::default().with_alpha(10.0),
+        )
+        .unwrap();
+
+        assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
     }
 
     #[test]
@@ -622,6 +717,12 @@ mod tests {
         ];
 
         let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
+        let lr_reg = LogisticRegression::fit(
+            &x,
+            &y,
+            LogisticRegressionParameters::default().with_alpha(1.0),
+        )
+        .unwrap();
 
         let y_hat = lr.predict(&x).unwrap();
 
@@ -632,5 +733,6 @@ mod tests {
             .sum();
 
         assert!(error <= 1.0);
+        assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
     }
 }
