Skip to content

Apply rustfmt #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use thiserror::Error;
use ndarray::ShapeError;

use thiserror::Error;

/// Enum provides error types
#[derive(Error, Debug)]
Expand Down
22 changes: 10 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,43 +120,41 @@
//!

mod errors;
mod traits;
mod ndarrayext;
mod ndg;
mod sprsext;
mod validate;
mod util;
mod traits;
mod umv;
mod ndg;
mod util;
mod validate;

use std::result;

/// Provides result type for `make` and `evaluate` methods
pub type Result<T> = result::Result<T, errors::CsapsError>;

pub use errors::CsapsError;
pub use ndg::{GridCubicSmoothingSpline, NdGridSpline};
pub use traits::{Real, RealRef};
pub use umv::{NdSpline, CubicSmoothingSpline};
pub use ndg::{NdGridSpline, GridCubicSmoothingSpline};

pub use umv::{CubicSmoothingSpline, NdSpline};

// #[cfg(test)]
// mod tests {
// use crate::CubicSmoothingSpline;
// use crate::CubicSmoothingSpline;
// use ndarray::prelude::*;

// #[test]
// fn test_new() {
// fn test_new() {

// let zeros = Array1::<f64>::zeros(1);
// let zeros = Array1::<f64>::zeros(1);

// let x = zeros.view();
// let zeros = Array2::<f64>::zeros((1,1));
// let y = zeros.view();


// let sp = CubicSmoothingSpline::new(x.view(), y.view())
// // .with_optional_weights(weights)
// // .with_optional_smooth(s)
// .make();
// }
// }
// }
189 changes: 103 additions & 86 deletions src/ndarrayext.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
use ndarray::{prelude::*, IntoDimension, Slice};
use itertools::Itertools;

use ndarray::{prelude::*, IntoDimension, Slice};

use crate::{
Result,
util::dim_from_vec,
CsapsError::{ReshapeFrom2d, ReshapeTo2d},
util::dim_from_vec, Real
Real, Result,
};


pub fn diff<'a, T: 'a, D, V>(data: V, axis: Option<Axis>) -> Array<T, D>
where
T: Real<T>,
D: Dimension,
V: AsArray<'a, T, D>
V: AsArray<'a, T, D>,
{
let data_view = data.into();
let axis = axis.unwrap_or_else(|| Axis(data_view.ndim() - 1));
Expand All @@ -24,11 +22,10 @@ where
&tail - &head
}


pub fn to_2d<'a, T: 'a, D, I>(data: I, axis: Axis) -> Result<ArrayView2<'a, T>>
where
D: Dimension,
I: AsArray<'a, T, D>,
where
D: Dimension,
I: AsArray<'a, T, D>,
{
let data_view = data.into();
let ndim = data_view.ndim();
Expand All @@ -48,45 +45,43 @@ pub fn to_2d<'a, T: 'a, D, I>(data: I, axis: Axis) -> Result<ArrayView2<'a, T>>

match data_view.permuted_axes(axes).into_shape(new_shape) {
Ok(view_2d) => Ok(view_2d),
Err(error) => Err(
ReshapeTo2d {
input_shape: shape,
output_shape: new_shape.to_vec(),
axis: axis.0,
source: error,
}
)
Err(error) => Err(ReshapeTo2d {
input_shape: shape,
output_shape: new_shape.to_vec(),
axis: axis.0,
source: error,
}),
}
}


pub fn to_2d_simple<'a, T: 'a, D>(data: ArrayView<'a, T, D>) -> Result<ArrayView2<'a, T>>
where
D: Dimension
where
D: Dimension,
{
let ndim = data.ndim();
let shape = data.shape().to_vec();
let new_shape = [shape[0..(ndim - 1)].iter().product(), shape[ndim - 1]];

match data.into_shape(new_shape) {
Ok(data_2d) => Ok(data_2d),
Err(error) => Err(
ReshapeTo2d {
input_shape: shape,
output_shape: new_shape.to_vec(),
axis: ndim - 1,
source: error,
}
)
Err(error) => Err(ReshapeTo2d {
input_shape: shape,
output_shape: new_shape.to_vec(),
axis: ndim - 1,
source: error,
}),
}
}


pub fn from_2d<'a, T: 'a, D, S, I>(data: I, shape: S, axis: Axis) -> Result<ArrayView<'a, T, S::Dim>>
where
D: Dimension,
S: IntoDimension<Dim = D>,
I: AsArray<'a, T, Ix2>,
pub fn from_2d<'a, T: 'a, D, S, I>(
data: I,
shape: S,
axis: Axis,
) -> Result<ArrayView<'a, T, S::Dim>>
where
D: Dimension,
S: IntoDimension<Dim = D>,
I: AsArray<'a, T, Ix2>,
{
let shape = shape.into_dimension();
let ndim = shape.ndim();
Expand All @@ -106,39 +101,37 @@ pub fn from_2d<'a, T: 'a, D, S, I>(data: I, shape: S, axis: Axis) -> Result<Arra

let axes: D = dim_from_vec(ndim, axes_tmp);
Ok(view_nd.permuted_axes(axes))
},
Err(error) => Err(
ReshapeFrom2d {
input_shape: data_view.shape().to_vec(),
output_shape: new_shape_vec,
axis: axis.0,
source: error,
}
)
}
Err(error) => Err(ReshapeFrom2d {
input_shape: data_view.shape().to_vec(),
output_shape: new_shape_vec,
axis: axis.0,
source: error,
}),
}
}


/// Returns the indices of the bins to which each value in input array belongs
///
/// This code works if `bins` is increasing
pub fn digitize<'a, T: 'a, A, B>(arr: A, bins: B) -> Array1<usize>
where
T: Real<T>,
// T: Clone + NdFloat + AlmostEqual,

A: AsArray<'a, T, Ix1>,
B: AsArray<'a, T, Ix1>,
where
T: Real<T>,
// T: Clone + NdFloat + AlmostEqual,
A: AsArray<'a, T, Ix1>,
B: AsArray<'a, T, Ix1>,
{
let arr_view = arr.into();
let bins_view = bins.into();

let mut indices = Array1::zeros((arr_view.len(),));
let mut kstart: usize = 0;

for (i, &a) in arr_view.iter().enumerate()
.sorted_by(|e1, e2| e1.1.partial_cmp(e2.1).unwrap()) {

for (i, &a) in arr_view
.iter()
.enumerate()
.sorted_by(|e1, e2| e1.1.partial_cmp(e2.1).unwrap())
{
let mut k = kstart;

for bins_win in bins_view.slice(s![kstart..]).windows(2) {
Expand All @@ -158,53 +151,55 @@ pub fn digitize<'a, T: 'a, A, B>(arr: A, bins: B) -> Array1<usize>
indices
}


#[cfg(test)]
mod tests {
use std::f64;
use ndarray::{array, Array1, Axis, Ix1, Ix2, Ix3};
use crate::ndarrayext::*;
use ndarray::{array, Array1, Axis, Ix1, Ix2, Ix3};
use std::f64;

#[test]
fn test_diff_1d() {
let a = array![1., 2., 3., 4., 5.];

assert_eq!(diff(&a, None),
array![1., 1., 1., 1.]);
assert_eq!(diff(&a, None), array![1., 1., 1., 1.]);

assert_eq!(diff(&a, Some(Axis(0))),
array![1., 1., 1., 1.]);
assert_eq!(diff(&a, Some(Axis(0))), array![1., 1., 1., 1.]);
}

#[test]
fn test_diff_2d() {
let a = array![[1., 2., 3., 4.], [1., 2., 3., 4.]];

assert_eq!(diff(&a, None),
array![[1., 1., 1.], [1., 1., 1.]]);
assert_eq!(diff(&a, None), array![[1., 1., 1.], [1., 1., 1.]]);

assert_eq!(diff(&a, Some(Axis(0))),
array![[0., 0., 0., 0.]]);
assert_eq!(diff(&a, Some(Axis(0))), array![[0., 0., 0., 0.]]);

assert_eq!(diff(&a, Some(Axis(1))),
array![[1., 1., 1.], [1., 1., 1.]]);
assert_eq!(diff(&a, Some(Axis(1))), array![[1., 1., 1.], [1., 1., 1.]]);
}

#[test]
fn test_diff_3d() {
let a = array![[[1., 2., 3.], [1., 2., 3.]], [[1., 2., 3.], [1., 2., 3.]]];

assert_eq!(diff(&a, None),
array![[[1., 1.], [1., 1.]], [[1., 1.], [1., 1.]]]);

assert_eq!(diff(&a, Some(Axis(0))),
array![[[0., 0., 0.], [0., 0., 0.]]]);

assert_eq!(diff(&a, Some(Axis(1))),
array![[[0., 0., 0.]], [[0., 0., 0.]]]);

assert_eq!(diff(&a, Some(Axis(2))),
array![[[1., 1.], [1., 1.]], [[1., 1.], [1., 1.]]]);
assert_eq!(
diff(&a, None),
array![[[1., 1.], [1., 1.]], [[1., 1.], [1., 1.]]]
);

assert_eq!(
diff(&a, Some(Axis(0))),
array![[[0., 0., 0.], [0., 0., 0.]]]
);

assert_eq!(
diff(&a, Some(Axis(1))),
array![[[0., 0., 0.]], [[0., 0., 0.]]]
);

assert_eq!(
diff(&a, Some(Axis(2))),
array![[[1., 1.], [1., 1.]], [[1., 1.], [1., 1.]]]
);
}

#[test]
Expand All @@ -218,8 +213,14 @@ mod tests {
fn test_to_2d_from_2d() {
let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

assert_eq!(to_2d(&a, Axis(0)).unwrap(), array![[1, 5], [2, 6], [3, 7], [4, 8]]);
assert_eq!(to_2d(&a, Axis(1)).unwrap(), array![[1, 2, 3, 4], [5, 6, 7, 8]]);
assert_eq!(
to_2d(&a, Axis(0)).unwrap(),
array![[1, 5], [2, 6], [3, 7], [4, 8]]
);
assert_eq!(
to_2d(&a, Axis(1)).unwrap(),
array![[1, 2, 3, 4], [5, 6, 7, 8]]
);
}

#[test]
Expand All @@ -229,7 +230,10 @@ mod tests {
// FIXME: incompatible memory layout
// assert_eq!(to_2d(&a, Axis(0)).unwrap(), array![[1, 7], [2, 8], [3, 9], [4, 10], [5, 11], [6, 12]]);
// assert_eq!(to_2d(&a, Axis(1)).unwrap(), array![[1, 4], [2, 5], [3, 6], [7, 10], [8, 11], [9, 12]]);
assert_eq!(to_2d(&a, Axis(2)).unwrap(), array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);
assert_eq!(
to_2d(&a, Axis(2)).unwrap(),
array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
);
}

#[test]
Expand All @@ -241,13 +245,19 @@ mod tests {
#[test]
fn test_to_2d_simple_from_2d() {
let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
assert_eq!(to_2d_simple(a.view()).unwrap(), array![[1, 2, 3, 4], [5, 6, 7, 8]]);
assert_eq!(
to_2d_simple(a.view()).unwrap(),
array![[1, 2, 3, 4], [5, 6, 7, 8]]
);
}

#[test]
fn test_to_2d_simple_from_3d() {
let a = array![[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]];
assert_eq!(to_2d_simple(a.view()).unwrap(), array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);
assert_eq!(
to_2d_simple(a.view()).unwrap(),
array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
);
}

#[test]
Expand All @@ -258,7 +268,8 @@ mod tests {

let r = from_2d(&a, s, Axis(2))
.unwrap()
.into_dimensionality::<Ix3>().unwrap();
.into_dimensionality::<Ix3>()
.unwrap();

assert_eq!(r, e);
}
Expand Down Expand Up @@ -369,7 +380,10 @@ mod tests {

let indices = digitize(&xi, &edges);

assert_eq!(indices, array![1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
assert_eq!(
indices,
array![1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]
)
}

#[test]
Expand All @@ -389,6 +403,9 @@ mod tests {

let indices = digitize(&xi, &edges);

assert_eq!(indices, array![0, 1, 0, 2, 2, 1, 0, 3, 4, 4, 3, 3, 2, 2, 1, 0])
assert_eq!(
indices,
array![0, 1, 0, 2, 2, 1, 0, 3, 4, 4, 3, 3, 2, 2, 1, 0]
)
}
}
Loading
Loading