Skip to content

Commit 87d4e9a

Browse files
Merge pull request #71 from smartcorelib/log_regression_solvers
feat: adds a new parameter to the logistic regression: solver
2 parents 272aabc + bd5fbb6 commit 87d4e9a

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

src/linear/logistic_regression.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,21 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
6868
use crate::optimization::line_search::Backtracking;
6969
use crate::optimization::FunctionOrder;
7070

71+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
72+
#[derive(Debug, Clone)]
73+
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
74+
pub enum LogisticRegressionSolverName {
75+
/// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
76+
LBFGS,
77+
}
78+
7179
/// Logistic Regression parameters
7280
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
7381
#[derive(Debug, Clone)]
74-
pub struct LogisticRegressionParameters {}
82+
pub struct LogisticRegressionParameters {
83+
/// Solver to use for estimation of regression coefficients.
84+
pub solver: LogisticRegressionSolverName,
85+
}
7586

7687
/// Logistic Regression
7788
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -105,9 +116,19 @@ struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
105116
phantom: PhantomData<&'a T>,
106117
}
107118

119+
impl LogisticRegressionParameters {
120+
/// Solver to use for estimation of regression coefficients.
121+
pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self {
122+
self.solver = solver;
123+
self
124+
}
125+
}
126+
108127
impl Default for LogisticRegressionParameters {
109128
fn default() -> Self {
110-
LogisticRegressionParameters {}
129+
LogisticRegressionParameters {
130+
solver: LogisticRegressionSolverName::LBFGS,
131+
}
111132
}
112133
}
113134

0 commit comments

Comments
 (0)