@@ -65,8 +65,11 @@ use high_order::HighOrderOperations;
65
65
use lu:: LUDecomposableMatrix ;
66
66
use qr:: QRDecomposableMatrix ;
67
67
use stats:: { MatrixPreprocessing , MatrixStats } ;
68
+ use std:: fs;
68
69
use svd:: SVDDecomposableMatrix ;
69
70
71
+ use crate :: readers;
72
+
70
73
/// Column or row vector
71
74
pub trait BaseVector < T : RealNumber > : Clone + Debug {
72
75
/// Get an element of a vector
@@ -298,9 +301,60 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
298
301
/// represents a row in this matrix.
299
302
type RowVector : BaseVector < T > + Clone + Debug ;
300
303
304
+ /// Create a matrix from a csv file.
305
+ /// ```
306
+ /// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
307
+ /// use smartcore::linalg::BaseMatrix;
308
+ /// use smartcore::readers::csv;
309
+ /// use std::fs;
310
+ ///
311
+ /// fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
312
+ /// assert_eq!(
313
+ /// DenseMatrix::<f64>::from_csv("identity.csv", csv::CSVDefinition::default()).unwrap(),
314
+ /// DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
315
+ /// );
316
+ /// fs::remove_file("identity.csv");
317
+ /// ```
318
+ fn from_csv (
319
+ path : & str ,
320
+ definition : readers:: csv:: CSVDefinition < ' _ > ,
321
+ ) -> Result < Self , readers:: ReadingError > {
322
+ readers:: csv:: matrix_from_csv_source ( fs:: File :: open ( path) ?, definition)
323
+ }
324
+
301
325
/// Transforms row vector `vec` into a 1xM matrix.
302
326
fn from_row_vector ( vec : Self :: RowVector ) -> Self ;
303
327
328
+ /// Transforms Vector of n rows with dimension m into
329
+ /// a matrix nxm.
330
+ /// ```
331
+ /// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
332
+ /// use crate::smartcore::linalg::BaseMatrix;
333
+ ///
334
+ /// let eye = DenseMatrix::from_row_vectors(vec![vec![1., 0., 0.], vec![0., 1., 0.], vec![0., 0., 1.]])
335
+ /// .unwrap();
336
+ ///
337
+ /// assert_eq!(
338
+ /// eye,
339
+ /// DenseMatrix::from_2d_vec(&vec![
340
+ /// vec![1.0, 0.0, 0.0],
341
+ /// vec![0.0, 1.0, 0.0],
342
+ /// vec![0.0, 0.0, 1.0],
343
+ /// ])
344
+ /// );
345
+ fn from_row_vectors ( rows : Vec < Self :: RowVector > ) -> Option < Self > {
346
+ if let Some ( first_row) = rows. first ( ) . cloned ( ) {
347
+ return Some ( rows. iter ( ) . skip ( 1 ) . cloned ( ) . fold (
348
+ Self :: from_row_vector ( first_row) ,
349
+ |current_matrix, new_row| {
350
+ current_matrix. v_stack ( & BaseMatrix :: from_row_vector ( new_row) )
351
+ } ,
352
+ ) ) ;
353
+ } else {
354
+ None
355
+ }
356
+ }
357
+
304
358
/// Transforms 1-d matrix of 1xM into a row vector.
305
359
fn to_row_vector ( self ) -> Self :: RowVector ;
306
360
@@ -782,4 +836,50 @@ mod tests {
782
836
"The second column was not extracted correctly"
783
837
) ;
784
838
}
839
+ mod matrix_from_csv {
840
+
841
+ use crate :: linalg:: naive:: dense_matrix:: DenseMatrix ;
842
+ use crate :: linalg:: BaseMatrix ;
843
+ use crate :: readers:: csv;
844
+ use crate :: readers:: io_testing;
845
+ use crate :: readers:: ReadingError ;
846
+
847
+ #[ test]
848
+ fn simple_read_default_csv ( ) {
849
+ let test_csv_file = io_testing:: TemporaryTextFile :: new (
850
+ "'sepal.length','sepal.width','petal.length','petal.width'\n \
851
+ 5.1,3.5,1.4,0.2\n \
852
+ 4.9,3,1.4,0.2\n \
853
+ 4.7,3.2,1.3,0.2",
854
+ ) ;
855
+
856
+ assert_eq ! (
857
+ DenseMatrix :: <f64 >:: from_csv(
858
+ test_csv_file
859
+ . expect( "Temporary file could not be written." )
860
+ . path( ) ,
861
+ csv:: CSVDefinition :: default ( )
862
+ ) ,
863
+ Ok ( DenseMatrix :: from_2d_array( & [
864
+ & [ 5.1 , 3.5 , 1.4 , 0.2 ] ,
865
+ & [ 4.9 , 3.0 , 1.4 , 0.2 ] ,
866
+ & [ 4.7 , 3.2 , 1.3 , 0.2 ] ,
867
+ ] ) )
868
+ )
869
+ }
870
+
871
+ #[ test]
872
+ fn non_existant_input_file ( ) {
873
+ let potential_error =
874
+ DenseMatrix :: < f64 > :: from_csv ( "/invalid/path" , csv:: CSVDefinition :: default ( ) ) ;
875
+ // The exact message is operating system dependant, therefore, I only test that the correct type
876
+ // error was returned.
877
+ assert_eq ! (
878
+ potential_error. clone( ) ,
879
+ Err ( ReadingError :: CouldNotReadFileSystem {
880
+ msg: String :: from( potential_error. err( ) . unwrap( ) . message( ) . unwrap( ) )
881
+ } )
882
+ )
883
+ }
884
+ }
785
885
}
0 commit comments