Skip to content

Commit 2dfa40f

Browse files
committed
FEAT: Let RandomExt methods apply to array views if possible (sample)
The RandomExt methods for sampling were unintentionally restricted to owned arrays only (like the original random constructors). Now the methods which can also apply to array views.
1 parent 37e4070 commit 2dfa40f

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

ndarray-rand/src/lib.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use crate::rand::seq::index;
3535
use crate::rand::{thread_rng, Rng, SeedableRng};
3636

3737
use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder};
38-
use ndarray::{ArrayBase, DataOwned, Dimension};
38+
use ndarray::{ArrayBase, DataOwned, RawData, Data, Dimension};
3939
#[cfg(feature = "quickcheck")]
4040
use quickcheck::{Arbitrary, Gen};
4141

@@ -63,7 +63,7 @@ pub mod rand_distr {
6363
/// [`.random_using()`](#tymethod.random_using).
6464
pub trait RandomExt<S, A, D>
6565
where
66-
S: DataOwned<Elem = A>,
66+
S: RawData<Elem = A>,
6767
D: Dimension,
6868
{
6969
/// Create an array with shape `dim` with elements drawn from
@@ -87,6 +87,7 @@ where
8787
fn random<Sh, IdS>(shape: Sh, distribution: IdS) -> ArrayBase<S, D>
8888
where
8989
IdS: Distribution<S::Elem>,
90+
S: DataOwned<Elem = A>,
9091
Sh: ShapeBuilder<Dim = D>;
9192

9293
/// Create an array with shape `dim` with elements drawn from
@@ -117,6 +118,7 @@ where
117118
where
118119
IdS: Distribution<S::Elem>,
119120
R: Rng + ?Sized,
121+
S: DataOwned<Elem = A>,
120122
Sh: ShapeBuilder<Dim = D>;
121123

122124
/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
@@ -163,6 +165,7 @@ where
163165
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
164166
where
165167
A: Copy,
168+
S: Data<Elem = A>,
166169
D: RemoveAxis;
167170

168171
/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
@@ -223,17 +226,19 @@ where
223226
where
224227
R: Rng + ?Sized,
225228
A: Copy,
229+
S: Data<Elem = A>,
226230
D: RemoveAxis;
227231
}
228232

229233
impl<S, A, D> RandomExt<S, A, D> for ArrayBase<S, D>
230234
where
231-
S: DataOwned<Elem = A>,
235+
S: RawData<Elem = A>,
232236
D: Dimension,
233237
{
234238
fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
235239
where
236240
IdS: Distribution<S::Elem>,
241+
S: DataOwned<Elem = A>,
237242
Sh: ShapeBuilder<Dim = D>,
238243
{
239244
Self::random_using(shape, dist, &mut get_rng())
@@ -243,6 +248,7 @@ where
243248
where
244249
IdS: Distribution<S::Elem>,
245250
R: Rng + ?Sized,
251+
S: DataOwned<Elem = A>,
246252
Sh: ShapeBuilder<Dim = D>,
247253
{
248254
Self::from_shape_simple_fn(shape, move || dist.sample(rng))
@@ -251,6 +257,7 @@ where
251257
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
252258
where
253259
A: Copy,
260+
S: Data<Elem = A>,
254261
D: RemoveAxis,
255262
{
256263
self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
@@ -266,6 +273,7 @@ where
266273
where
267274
R: Rng + ?Sized,
268275
A: Copy,
276+
S: Data<Elem = A>,
269277
D: RemoveAxis,
270278
{
271279
let indices: Vec<_> = match strategy {

ndarray-rand/tests/tests.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ fn test_dim_f() {
3535
}
3636
}
3737

38+
#[test]
39+
fn sample_axis_on_view() {
40+
let m = 5;
41+
let a = Array::random((m, 4), Uniform::new(0., 2.));
42+
let _samples = a.view().sample_axis(Axis(0), m, SamplingStrategy::WithoutReplacement);
43+
}
44+
3845
#[test]
3946
#[should_panic]
4047
fn oversampling_without_replacement_should_panic() {

0 commit comments

Comments
 (0)