Skip to content

Commit b6f585e

Browse files
authored
Implement a generic read_csv method (#147)
* feat: Add interface to build `Matrix` from rows. * feat: Add option to derive `RealNumber` from string. To construct a `Matrix` from csv, and therefore from string, I need to be able to deserialize a generic `RealNumber` from string. * feat: Implement `Matrix::read_csv`.
1 parent 4685fc7 commit b6f585e

File tree

7 files changed

+841
-0
lines changed

7 files changed

+841
-0
lines changed

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ pub mod neighbors;
9595
pub(crate) mod optimization;
9696
/// Preprocessing utilities
9797
pub mod preprocessing;
98+
/// Reading in Data.
99+
pub mod readers;
98100
/// Support Vector Machines
99101
pub mod svm;
100102
/// Supervised tree-based learning methods

src/linalg/mod.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,11 @@ use high_order::HighOrderOperations;
6565
use lu::LUDecomposableMatrix;
6666
use qr::QRDecomposableMatrix;
6767
use stats::{MatrixPreprocessing, MatrixStats};
68+
use std::fs;
6869
use svd::SVDDecomposableMatrix;
6970

71+
use crate::readers;
72+
7073
/// Column or row vector
7174
pub trait BaseVector<T: RealNumber>: Clone + Debug {
7275
/// Get an element of a vector
@@ -298,9 +301,60 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
298301
/// represents a row in this matrix.
299302
type RowVector: BaseVector<T> + Clone + Debug;
300303

304+
/// Create a matrix from a csv file.
305+
/// ```
306+
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
307+
/// use smartcore::linalg::BaseMatrix;
308+
/// use smartcore::readers::csv;
309+
/// use std::fs;
310+
///
311+
/// fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
312+
/// assert_eq!(
313+
/// DenseMatrix::<f64>::from_csv("identity.csv", csv::CSVDefinition::default()).unwrap(),
314+
/// DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
315+
/// );
316+
/// fs::remove_file("identity.csv");
317+
/// ```
318+
fn from_csv(
319+
path: &str,
320+
definition: readers::csv::CSVDefinition<'_>,
321+
) -> Result<Self, readers::ReadingError> {
322+
readers::csv::matrix_from_csv_source(fs::File::open(path)?, definition)
323+
}
324+
301325
/// Transforms row vector `vec` into a 1xM matrix.
302326
fn from_row_vector(vec: Self::RowVector) -> Self;
303327

328+
/// Transforms Vector of n rows with dimension m into
329+
/// a matrix nxm.
330+
/// ```
331+
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
332+
/// use crate::smartcore::linalg::BaseMatrix;
333+
///
334+
/// let eye = DenseMatrix::from_row_vectors(vec![vec![1., 0., 0.], vec![0., 1., 0.], vec![0., 0., 1.]])
335+
/// .unwrap();
336+
///
337+
/// assert_eq!(
338+
/// eye,
339+
/// DenseMatrix::from_2d_vec(&vec![
340+
/// vec![1.0, 0.0, 0.0],
341+
/// vec![0.0, 1.0, 0.0],
342+
/// vec![0.0, 0.0, 1.0],
343+
/// ])
344+
/// );
345+
fn from_row_vectors(rows: Vec<Self::RowVector>) -> Option<Self> {
346+
if let Some(first_row) = rows.first().cloned() {
347+
return Some(rows.iter().skip(1).cloned().fold(
348+
Self::from_row_vector(first_row),
349+
|current_matrix, new_row| {
350+
current_matrix.v_stack(&BaseMatrix::from_row_vector(new_row))
351+
},
352+
));
353+
} else {
354+
None
355+
}
356+
}
357+
304358
/// Transforms 1-d matrix of 1xM into a row vector.
305359
fn to_row_vector(self) -> Self::RowVector;
306360

@@ -782,4 +836,50 @@ mod tests {
782836
"The second column was not extracted correctly"
783837
);
784838
}
839+
mod matrix_from_csv {
840+
841+
use crate::linalg::naive::dense_matrix::DenseMatrix;
842+
use crate::linalg::BaseMatrix;
843+
use crate::readers::csv;
844+
use crate::readers::io_testing;
845+
use crate::readers::ReadingError;
846+
847+
#[test]
848+
fn simple_read_default_csv() {
849+
let test_csv_file = io_testing::TemporaryTextFile::new(
850+
"'sepal.length','sepal.width','petal.length','petal.width'\n\
851+
5.1,3.5,1.4,0.2\n\
852+
4.9,3,1.4,0.2\n\
853+
4.7,3.2,1.3,0.2",
854+
);
855+
856+
assert_eq!(
857+
DenseMatrix::<f64>::from_csv(
858+
test_csv_file
859+
.expect("Temporary file could not be written.")
860+
.path(),
861+
csv::CSVDefinition::default()
862+
),
863+
Ok(DenseMatrix::from_2d_array(&[
864+
&[5.1, 3.5, 1.4, 0.2],
865+
&[4.9, 3.0, 1.4, 0.2],
866+
&[4.7, 3.2, 1.3, 0.2],
867+
]))
868+
)
869+
}
870+
871+
#[test]
872+
fn non_existant_input_file() {
873+
let potential_error =
874+
DenseMatrix::<f64>::from_csv("/invalid/path", csv::CSVDefinition::default());
875+
// The exact message is operating system dependant, therefore, I only test that the correct type
876+
// error was returned.
877+
assert_eq!(
878+
potential_error.clone(),
879+
Err(ReadingError::CouldNotReadFileSystem {
880+
msg: String::from(potential_error.err().unwrap().message().unwrap())
881+
})
882+
)
883+
}
884+
}
785885
}

src/math/num.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use rand::prelude::*;
77
use std::fmt::{Debug, Display};
88
use std::iter::{Product, Sum};
99
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
10+
use std::str::FromStr;
1011

1112
/// Defines real number
1213
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
@@ -22,6 +23,7 @@ pub trait RealNumber:
2223
+ SubAssign
2324
+ MulAssign
2425
+ DivAssign
26+
+ FromStr
2527
{
2628
/// Copy sign from `sign` - another real number
2729
fn copysign(self, sign: Self) -> Self;
@@ -154,4 +156,14 @@ mod tests {
154156
assert_eq!(41.0.sigmoid(), 1.);
155157
assert_eq!((-41.0).sigmoid(), 0.);
156158
}
159+
160+
/// An `f32` can be parsed from its decimal string representation.
#[test]
fn f32_from_string() {
    let parsed = f32::from_str("1.111111").unwrap();
    assert_eq!(parsed, 1.111111);
}
164+
165+
/// An `f64` can be parsed from its decimal string representation.
#[test]
fn f64_from_string() {
    let parsed = f64::from_str("1.111111111").unwrap();
    assert_eq!(parsed, 1.111111111);
}
157169
}

0 commit comments

Comments
 (0)