Skip to content

Commit ca3a3a1

Browse files
Volodymyr OrlovVolodymyr Orlov
authored andcommitted
fix: ridge regression, post-review changes
1 parent 83048db commit ca3a3a1

File tree

2 files changed

+48
-13
lines changed

2 files changed

+48
-13
lines changed

src/linalg/mod.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -168,16 +168,10 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
168168

169169
/// Computes the arithmetic mean.
170170
fn mean(&self) -> T {
171-
let n = self.len();
172-
let mut mean = T::zero();
173-
174-
for i in 0..n {
175-
mean += self.get(i);
176-
}
177-
mean / T::from_usize(n).unwrap()
171+
self.sum() / T::from_usize(self.len()).unwrap()
178172
}
179-
/// Computes the standard deviation.
180-
fn std(&self) -> T {
173+
/// Computes variance.
174+
fn var(&self) -> T {
181175
let n = self.len();
182176

183177
let mut mu = T::zero();
@@ -189,7 +183,11 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
189183
sum += xi * xi;
190184
}
191185
mu /= div;
192-
(sum / div - mu * mu).sqrt()
186+
sum / div - mu * mu
187+
}
188+
/// Computes the standard deviation.
189+
fn std(&self) -> T {
190+
self.var().sqrt()
193191
}
194192
}
195193

@@ -592,4 +590,11 @@ mod tests {
592590

593591
assert!((m.std() - 0.81f64).abs() < 1e-2);
594592
}
593+
594+
#[test]
595+
fn var() {
596+
let m = vec![1., 2., 3., 4.];
597+
598+
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
599+
}
595600
}

src/linalg/stats.rs

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
3535
x
3636
}
3737

38-
/// Computes the standard deviation along the specified axis.
39-
fn std(&self, axis: u8) -> Vec<T> {
38+
/// Computes variance along the specified axis.
39+
fn var(&self, axis: u8) -> Vec<T> {
4040
let (n, m) = match axis {
4141
0 => {
4242
let (n, m) = self.shape();
@@ -61,7 +61,24 @@ pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
6161
sum += a * a;
6262
}
6363
mu /= div;
64-
x[i] = (sum / div - mu * mu).sqrt();
64+
x[i] = sum / div - mu * mu;
65+
}
66+
67+
x
68+
}
69+
70+
/// Computes the standard deviation along the specified axis.
71+
fn std(&self, axis: u8) -> Vec<T> {
72+
73+
let mut x = self.var(axis);
74+
75+
let n = match axis {
76+
0 => self.shape().1,
77+
_ => self.shape().0,
78+
};
79+
80+
for i in 0..n {
81+
x[i] = x[i].sqrt();
6582
}
6683

6784
x
@@ -122,6 +139,19 @@ mod tests {
122139
assert!(m.std(1).approximate_eq(&expected_1, 1e-2));
123140
}
124141

142+
#[test]
143+
fn var() {
144+
let m = DenseMatrix::from_2d_array(&[
145+
&[1., 2., 3., 4.],
146+
&[5., 6., 7., 8.]
147+
]);
148+
let expected_0 = vec![4., 4., 4., 4.];
149+
let expected_1 = vec![1.25, 1.25];
150+
151+
assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON));
152+
assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON));
153+
}
154+
125155
#[test]
126156
fn scale() {
127157
let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);

0 commit comments

Comments
 (0)