Skip to content

Commit c21e752

Browse files
authored
feat: allocate first and then proceed to create matrix from Vec of Ro… (#159)
* feat: allocate first and then proceed to create matrix from Vec of RowVectors
1 parent 6a2e104 commit c21e752

File tree

1 file changed

+45
-9
lines changed

1 file changed

+45
-9
lines changed

src/linalg/mod.rs

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -343,16 +343,19 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
343343
/// ])
344344
/// );
345345
fn from_row_vectors(rows: Vec<Self::RowVector>) -> Option<Self> {
346-
if let Some(first_row) = rows.first().cloned() {
347-
return Some(rows.iter().skip(1).cloned().fold(
348-
Self::from_row_vector(first_row),
349-
|current_matrix, new_row| {
350-
current_matrix.v_stack(&BaseMatrix::from_row_vector(new_row))
351-
},
352-
));
353-
} else {
354-
None
346+
if rows.is_empty() {
347+
return None;
348+
}
349+
let n = rows.len();
350+
let m = rows[0].len();
351+
352+
let mut result = Self::zeros(n, m);
353+
354+
for (row_idx, row) in rows.into_iter().enumerate() {
355+
result.set_row(row_idx, row);
355356
}
357+
358+
Some(result)
356359
}
357360

358361
/// Transforms 1-d matrix of 1xM into a row vector.
@@ -376,6 +379,13 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
376379
/// * `result` - receiver for the row
377380
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>);
378381

382+
/// Set row vector at row `row_idx`.
383+
fn set_row(&mut self, row_idx: usize, row: Self::RowVector) {
384+
for (col_idx, val) in row.to_vec().into_iter().enumerate() {
385+
self.set(row_idx, col_idx, val);
386+
}
387+
}
388+
379389
/// Get a vector with elements of the `col`'th column
380390
/// * `col` - column number
381391
fn get_col_as_vec(&self, col: usize) -> Vec<T>;
@@ -836,6 +846,32 @@ mod tests {
836846
"The second column was not extracted correctly"
837847
);
838848
}
849+
850+
#[test]
851+
fn test_from_row_vectors_simple() {
852+
let eye = DenseMatrix::from_row_vectors(vec![
853+
vec![1., 0., 0.],
854+
vec![0., 1., 0.],
855+
vec![0., 0., 1.],
856+
])
857+
.unwrap();
858+
assert_eq!(
859+
eye,
860+
DenseMatrix::from_2d_vec(&vec![
861+
vec![1.0, 0.0, 0.0],
862+
vec![0.0, 1.0, 0.0],
863+
vec![0.0, 0.0, 1.0],
864+
])
865+
);
866+
}
867+
868+
#[test]
869+
fn test_from_row_vectors_large() {
870+
let eye = DenseMatrix::from_row_vectors(vec![vec![4.25; 5000]; 5000]).unwrap();
871+
872+
assert_eq!(eye.shape(), (5000, 5000));
873+
assert_eq!(eye.get_row(5), vec![4.25; 5000]);
874+
}
839875
mod matrix_from_csv {
840876

841877
use crate::linalg::naive::dense_matrix::DenseMatrix;

0 commit comments

Comments
 (0)