Skip to content

Commit 900078c

Browse files
committed
Implement abstract method to convert a slice to a BaseVector, Implement RealNumberVector over BaseVector instead of over Vec<T>
1 parent 82464f4 commit 900078c

File tree

3 files changed

+22
-9
lines changed

3 files changed

+22
-9
lines changed

src/linalg/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,21 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
8383
self.len() == 0
8484
}
8585

86+
/// Create a new vector from a &[T]
87+
/// ```
88+
/// use smartcore::linalg::naive::dense_matrix::*;
89+
/// let slice: &[f64] = &[0., 0.5, 2., 3., 4.];
90+
/// let a: Vec<f64> = BaseVector::from_slice(slice);
91+
/// assert_eq!(a, vec![0., 0.5, 2., 3., 4.]);
92+
/// ```
93+
fn from_slice(f: &[T]) -> Self {
94+
let mut v = Self::zeros(f.len());
95+
for (i, elem) in f.iter().enumerate() {
96+
v.set(i, *elem);
97+
}
98+
v
99+
}
100+
86101
/// Return a vector with the elements of the one-dimensional array.
87102
fn to_vec(&self) -> Vec<T>;
88103

src/math/vector.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use crate::math::num::RealNumber;
22
use std::collections::HashMap;
33

4+
use crate::linalg::BaseVector;
45
pub trait RealNumberVector<T: RealNumber> {
56
fn unique(&self) -> (Vec<T>, Vec<usize>);
67
}
78

8-
impl<T: RealNumber> RealNumberVector<T> for Vec<T> {
9+
impl<T: RealNumber, V: BaseVector<T>> RealNumberVector<T> for V {
910
fn unique(&self) -> (Vec<T>, Vec<usize>) {
10-
let mut unique = self.clone();
11+
let mut unique = self.to_vec();
1112
unique.sort_by(|a, b| a.partial_cmp(b).unwrap());
1213
unique.dedup();
1314

@@ -17,8 +18,8 @@ impl<T: RealNumber> RealNumberVector<T> for Vec<T> {
1718
}
1819

1920
let mut unique_index = Vec::with_capacity(self.len());
20-
for e in self {
21-
unique_index.push(index[&e.to_i64().unwrap()]);
21+
for idx in 0..self.len() {
22+
unique_index.push(index[&self.get(idx).to_i64().unwrap()]);
2223
}
2324

2425
(unique, unique_index)
@@ -27,7 +28,7 @@ impl<T: RealNumber> RealNumberVector<T> for Vec<T> {
2728

2829
#[cfg(test)]
2930
mod tests {
30-
use super::*;
31+
use super::RealNumberVector;
3132

3233
#[test]
3334
fn unique() {

src/naive_bayes/mod.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,7 @@ impl<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> BaseNaiveBayes<T, M,
5858
*prediction
5959
})
6060
.collect::<Vec<T>>();
61-
let mut y_hat = M::RowVector::zeros(rows);
62-
for (i, prediction) in predictions.iter().enumerate().take(rows) {
63-
y_hat.set(i, *prediction);
64-
}
61+
let y_hat = M::RowVector::from_slice(&predictions);
6562
Ok(y_hat)
6663
}
6764
}

0 commit comments

Comments
 (0)