Skip to content

Commit d4e9669

Browse files
committed
Add functional methods to DenseMatrix implementation
1 parent 724268b commit d4e9669

File tree

1 file changed

+44
-64
lines changed

1 file changed

+44
-64
lines changed

src/linalg/basic/matrix.rs

Lines changed: 44 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -51,67 +51,20 @@ pub struct DenseMatrixMutView<'a, T: Debug + Display + Copy + Sized> {
5151
column_major: bool,
5252
}
5353

54-
// functional utility functions used across types
55-
fn is_valid_matrix_window(
56-
mrows: usize,
57-
mcols: usize,
58-
vrows: &Range<usize>,
59-
vcols: &Range<usize>,
60-
) -> bool {
61-
debug_assert!(
62-
vrows.end <= mrows && vcols.end <= mcols,
63-
"The window end is outside of the matrix range"
64-
);
65-
debug_assert!(
66-
vrows.start <= mrows && vcols.start <= mcols,
67-
"The window start is outside of the matrix range"
68-
);
69-
debug_assert!(
70-
// depends on a properly formed range
71-
vrows.start <= vrows.end && vcols.start <= vcols.end,
72-
"Invalid range: start <= end failed"
73-
);
74-
75-
!(vrows.end <= mrows && vcols.end <= mcols && vrows.start <= mrows && vcols.start <= mcols)
76-
}
77-
fn start_end_stride(
78-
mrows: usize,
79-
mcols: usize,
80-
vrows: &Range<usize>,
81-
vcols: &Range<usize>,
82-
column_major: bool,
83-
) -> (usize, usize, usize) {
84-
let (start, end, stride) = if column_major {
85-
(
86-
vrows.start + vcols.start * mrows,
87-
vrows.end + (vcols.end - 1) * mrows,
88-
mrows,
89-
)
90-
} else {
91-
(
92-
vrows.start * mcols + vcols.start,
93-
(vrows.end - 1) * mcols + vcols.end,
94-
mcols,
95-
)
96-
};
97-
(start, end, stride)
98-
}
9954

10055
impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
10156
fn new(
10257
m: &'a DenseMatrix<T>,
10358
vrows: Range<usize>,
10459
vcols: Range<usize>,
10560
) -> Result<Self, Failed> {
106-
let (mrows, mcols) = m.shape();
107-
108-
if is_valid_matrix_window(mrows, mcols, &vrows, &vcols) {
61+
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
10962
Err(Failed::input(
110-
"The specified window is outside of the matrix range",
63+
"The specified view is outside of the matrix range"
11164
))
11265
} else {
11366
let (start, end, stride) =
114-
start_end_stride(mrows, mcols, &vrows, &vcols, m.column_major);
67+
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
11568

11669
Ok(DenseMatrixView {
11770
values: &m.values[start..end],
@@ -157,14 +110,13 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
157110
vrows: Range<usize>,
158111
vcols: Range<usize>,
159112
) -> Result<Self, Failed> {
160-
let (mrows, mcols) = m.shape();
161-
if is_valid_matrix_window(mrows, mcols, &vrows, &vcols) {
113+
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
162114
Err(Failed::input(
163-
"The specified window is outside of the matrix range",
115+
"The specified view is outside of the matrix range"
164116
))
165117
} else {
166118
let (start, end, stride) =
167-
start_end_stride(mrows, mcols, &vrows, &vcols, m.column_major);
119+
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
168120

169121
Ok(DenseMatrixMutView {
170122
values: &mut m.values[start..end],
@@ -239,10 +191,6 @@ impl<T: Debug + Display + Copy + Sized> DenseMatrix<T> {
239191
values: Vec<T>,
240192
column_major: bool,
241193
) -> Result<Self, Failed> {
242-
debug_assert!(
243-
nrows * ncols == values.len(),
244-
"Instantiatint DenseMatrix requires nrows * ncols == values.len()"
245-
);
246194
let data_len = values.len();
247195
if nrows * ncols != values.len() {
248196
Err(Failed::input(&format!(
@@ -265,14 +213,9 @@ impl<T: Debug + Display + Copy + Sized> DenseMatrix<T> {
265213

266214
/// New instance of `DenseMatrix` from 2d vector.
267215
pub fn from_2d_vec(values: &Vec<Vec<T>>) -> Result<Self, Failed> {
268-
debug_assert!(
269-
!(values.is_empty() || values[0].is_empty()),
270-
"Instantiating DenseMatrix requires a non-empty 2d_vec"
271-
);
272-
273216
if values.is_empty() || values[0].is_empty() {
274217
Err(Failed::input(
275-
"The 2d vec provided is empty; cannot instantiate the matrix",
218+
"The 2d vec provided is empty; cannot instantiate the matrix"
276219
))
277220
} else {
278221
let nrows = values.len();
@@ -298,6 +241,43 @@ impl<T: Debug + Display + Copy + Sized> DenseMatrix<T> {
298241
pub fn iter(&self) -> Iter<'_, T> {
299242
self.values.iter()
300243
}
244+
245+
/// Check if the size of the requested view is bounded to matrix rows/cols count
246+
fn is_valid_view(
247+
&self,
248+
n_rows: usize,
249+
n_cols: usize,
250+
vrows: &Range<usize>,
251+
vcols: &Range<usize>,
252+
) -> bool {
253+
!(vrows.end <= n_rows && vcols.end <= n_cols && vrows.start <= n_rows && vcols.start <= n_cols)
254+
}
255+
256+
/// Compute the range of the requested view: start, end, size of the slice
257+
fn stride_range(
258+
&self,
259+
n_rows: usize,
260+
n_cols: usize,
261+
vrows: &Range<usize>,
262+
vcols: &Range<usize>,
263+
column_major: bool,
264+
) -> (usize, usize, usize) {
265+
let (start, end, stride) = if column_major {
266+
(
267+
vrows.start + vcols.start * n_rows,
268+
vrows.end + (vcols.end - 1) * n_rows,
269+
n_rows,
270+
)
271+
} else {
272+
(
273+
vrows.start * n_cols + vcols.start,
274+
(vrows.end - 1) * n_cols + vcols.end,
275+
n_cols,
276+
)
277+
};
278+
(start, end, stride)
279+
}
280+
301281
}
302282

303283
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrix<T> {

0 commit comments

Comments
 (0)