Skip to content

Commit 89a5136

Browse files
authored
Change implementation of to_row_vector for nalgebra (#34)
* Add failing test * Change implementation of to_row_vector for nalgebra
1 parent 9db9939 commit 89a5136

File tree

4 files changed

+23
-4
lines changed

4 files changed

+23
-4
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ datasets = []
2020

2121
[dependencies]
2222
ndarray = { version = "0.13", optional = true }
23-
nalgebra = { version = "0.22.0", optional = true }
23+
nalgebra = { version = "0.23.0", optional = true }
2424
num-traits = "0.2.12"
2525
num = "0.3.0"
2626
rand = "0.7.3"
@@ -35,4 +35,4 @@ bincode = "1.3.1"
3535

3636
[[bench]]
3737
name = "distance"
38-
harness = false
38+
harness = false

src/linalg/naive/dense_matrix.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,12 @@ mod tests {
10641064
);
10651065
}
10661066

1067+
#[test]
1068+
fn col_matrix_to_row_vector() {
1069+
let m: DenseMatrix<f64> = BaseMatrix::zeros(10, 1);
1070+
assert_eq!(m.to_row_vector().len(), 10)
1071+
}
1072+
10671073
#[test]
10681074
fn iter() {
10691075
let vec = vec![1., 2., 3., 4., 5., 6.];

src/linalg/nalgebra_bindings.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,14 +185,15 @@ impl<T: RealNumber + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
185185
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
186186
BaseMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
187187
{
188-
type RowVector = MatrixMN<T, U1, Dynamic>;
188+
type RowVector = RowDVector<T>;
189189

190190
fn from_row_vector(vec: Self::RowVector) -> Self {
191191
Matrix::from_rows(&[vec])
192192
}
193193

194194
fn to_row_vector(self) -> Self::RowVector {
195-
self.row(0).into_owned()
195+
let (nrows, ncols) = self.shape();
196+
self.reshape_generic(U1, Dynamic::new(nrows * ncols))
196197
}
197198

198199
fn get(&self, row: usize, col: usize) -> T {
@@ -697,6 +698,12 @@ mod tests {
697698
assert_eq!(m.to_row_vector(), expected);
698699
}
699700

701+
#[test]
702+
fn col_matrix_to_row_vector() {
703+
let m: DMatrix<f64> = BaseMatrix::zeros(10, 1);
704+
assert_eq!(m.to_row_vector().len(), 10)
705+
}
706+
700707
#[test]
701708
fn get_row_col_as_vec() {
702709
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);

src/linalg/ndarray_bindings.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,12 @@ mod tests {
563563
);
564564
}
565565

566+
#[test]
567+
fn col_matrix_to_row_vector() {
568+
let m: Array2<f64> = BaseMatrix::zeros(10, 1);
569+
assert_eq!(m.to_row_vector().len(), 10)
570+
}
571+
566572
#[test]
567573
fn add_mut() {
568574
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);

0 commit comments

Comments
 (0)