Skip to content

Commit ad7efff

Browse files
authored
Merge pull request #724 from LukeMathWalker/random-convenient-functions
Add lane sampling to ndarray-rand
2 parents ecb7643 + 0c9a2e3 commit ad7efff

File tree

5 files changed

+277
-12
lines changed

5 files changed

+277
-12
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ rawpointer = { version = "0.2" }
4646

4747
[dev-dependencies]
4848
defmac = "0.2"
49-
quickcheck = { version = "0.8", default-features = false }
49+
quickcheck = { version = "0.9", default-features = false }
5050
approx = "0.3.2"
5151
itertools = { version = "0.8.0", default-features = false, features = ["use_std"] }
5252

ndarray-rand/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ keywords = ["multidimensional", "matrix", "rand", "ndarray"]
1616
[dependencies]
1717
ndarray = { version = "0.13", path = ".." }
1818
rand_distr = "0.2.1"
19+
quickcheck = { version = "0.9", default-features = false, optional = true }
1920

2021
[dependencies.rand]
2122
version = "0.7.0"
2223
features = ["small_rng"]
2324

2425
[dev-dependencies]
2526
rand_isaac = "0.2.0"
27+
quickcheck = { version = "0.9", default-features = false }
2628

2729
[package.metadata.release]
2830
no-dev-version = true

ndarray-rand/src/lib.rs

Lines changed: 176 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,15 @@
2929
//! that the items are not compatible (e.g. that a type doesn't implement a
3030
//! necessary trait).
3131
32-
use crate::rand::distributions::Distribution;
32+
use crate::rand::distributions::{Distribution, Uniform};
3333
use crate::rand::rngs::SmallRng;
34+
use crate::rand::seq::index;
3435
use crate::rand::{thread_rng, Rng, SeedableRng};
3536

36-
use ndarray::ShapeBuilder;
37+
use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder};
3738
use ndarray::{ArrayBase, DataOwned, Dimension};
39+
#[cfg(feature = "quickcheck")]
40+
use quickcheck::{Arbitrary, Gen};
3841

3942
/// [`rand`](https://docs.rs/rand/0.7), re-exported for convenience and version-compatibility.
4043
pub mod rand {
@@ -59,9 +62,9 @@ pub mod rand_distr {
5962
/// low-quality random numbers, and reproducibility is not guaranteed. See its
6063
/// documentation for information. You can select a different RNG with
6164
/// [`.random_using()`](#tymethod.random_using).
62-
pub trait RandomExt<S, D>
65+
pub trait RandomExt<S, A, D>
6366
where
64-
S: DataOwned,
67+
S: DataOwned<Elem = A>,
6568
D: Dimension,
6669
{
6770
/// Create an array with shape `dim` with elements drawn from
@@ -116,21 +119,125 @@ where
116119
IdS: Distribution<S::Elem>,
117120
R: Rng + ?Sized,
118121
Sh: ShapeBuilder<Dim = D>;
122+
123+
/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
124+
///
125+
/// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
126+
/// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times.
127+
///
128+
/// ***Panics*** when:
129+
/// - creation of the RNG fails;
130+
/// - `n_samples` is greater than the length of `axis` (if sampling without replacement);
131+
/// - length of `axis` is 0.
132+
///
133+
/// ```
134+
/// use ndarray::{array, Axis};
135+
/// use ndarray_rand::{RandomExt, SamplingStrategy};
136+
///
137+
/// # fn main() {
138+
/// let a = array![
139+
/// [1., 2., 3.],
140+
/// [4., 5., 6.],
141+
/// [7., 8., 9.],
142+
/// [10., 11., 12.],
143+
/// ];
144+
/// // Sample 2 rows, without replacement
145+
/// let sample_rows = a.sample_axis(Axis(0), 2, SamplingStrategy::WithoutReplacement);
146+
/// println!("{:?}", sample_rows);
147+
/// // Example Output: (1st and 3rd rows)
148+
/// // [
149+
/// // [1., 2., 3.],
150+
/// // [7., 8., 9.]
151+
/// // ]
152+
/// // Sample 2 columns, with replacement
153+
/// let sample_columns = a.sample_axis(Axis(1), 2, SamplingStrategy::WithReplacement);
154+
/// println!("{:?}", sample_columns);
155+
/// // Example Output: (2nd column, sampled twice)
156+
/// // [
157+
/// // [2., 2.],
158+
/// // [5., 5.],
159+
/// // [8., 8.],
160+
/// // [11., 11.]
161+
/// // ]
162+
/// # }
163+
/// ```
164+
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
165+
where
166+
A: Copy,
167+
D: RemoveAxis;
168+
169+
/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
170+
///
171+
/// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
172+
/// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times.
173+
///
174+
/// ***Panics*** when:
175+
/// - creation of the RNG fails;
176+
/// - `n_samples` is greater than the length of `axis` (if sampling without replacement);
177+
/// - length of `axis` is 0.
178+
///
179+
/// ```
180+
/// use ndarray::{array, Axis};
181+
/// use ndarray_rand::{RandomExt, SamplingStrategy};
182+
/// use ndarray_rand::rand::SeedableRng;
183+
/// use rand_isaac::isaac64::Isaac64Rng;
184+
///
185+
/// # fn main() {
186+
/// // Get a seeded random number generator for reproducibility (Isaac64 algorithm)
187+
/// let seed = 42;
188+
/// let mut rng = Isaac64Rng::seed_from_u64(seed);
189+
///
190+
/// let a = array![
191+
/// [1., 2., 3.],
192+
/// [4., 5., 6.],
193+
/// [7., 8., 9.],
194+
/// [10., 11., 12.],
195+
/// ];
196+
/// // Sample 2 rows, without replacement
197+
/// let sample_rows = a.sample_axis_using(Axis(0), 2, SamplingStrategy::WithoutReplacement, &mut rng);
198+
/// println!("{:?}", sample_rows);
199+
/// // Example Output: (1st and 3rd rows)
200+
/// // [
201+
/// // [1., 2., 3.],
202+
/// // [7., 8., 9.]
203+
/// // ]
204+
///
205+
/// // Sample 2 columns, with replacement
206+
/// let sample_columns = a.sample_axis_using(Axis(1), 2, SamplingStrategy::WithReplacement, &mut rng);
207+
/// println!("{:?}", sample_columns);
208+
/// // Example Output: (2nd column, sampled twice)
209+
/// // [
210+
/// // [2., 2.],
211+
/// // [5., 5.],
212+
/// // [8., 8.],
213+
/// // [11., 11.]
214+
/// // ]
215+
/// # }
216+
/// ```
217+
fn sample_axis_using<R>(
218+
&self,
219+
axis: Axis,
220+
n_samples: usize,
221+
strategy: SamplingStrategy,
222+
rng: &mut R,
223+
) -> Array<A, D>
224+
where
225+
R: Rng + ?Sized,
226+
A: Copy,
227+
D: RemoveAxis;
119228
}
120229

121-
impl<S, D> RandomExt<S, D> for ArrayBase<S, D>
230+
impl<S, A, D> RandomExt<S, A, D> for ArrayBase<S, D>
122231
where
123-
S: DataOwned,
232+
S: DataOwned<Elem = A>,
124233
D: Dimension,
125234
{
126235
fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
127236
where
128237
IdS: Distribution<S::Elem>,
129238
Sh: ShapeBuilder<Dim = D>,
130239
{
131-
let mut rng =
132-
SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed");
133-
Self::random_using(shape, dist, &mut rng)
240+
Self::random_using(shape, dist, &mut get_rng())
134241
}
135242

136243
fn random_using<Sh, IdS, R>(shape: Sh, dist: IdS, rng: &mut R) -> ArrayBase<S, D>
@@ -141,6 +248,66 @@ where
141248
{
142249
Self::from_shape_fn(shape, |_| dist.sample(rng))
143250
}
251+
252+
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
253+
where
254+
A: Copy,
255+
D: RemoveAxis,
256+
{
257+
self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
258+
}
259+
260+
fn sample_axis_using<R>(
261+
&self,
262+
axis: Axis,
263+
n_samples: usize,
264+
strategy: SamplingStrategy,
265+
rng: &mut R,
266+
) -> Array<A, D>
267+
where
268+
R: Rng + ?Sized,
269+
A: Copy,
270+
D: RemoveAxis,
271+
{
272+
let indices: Vec<_> = match strategy {
273+
SamplingStrategy::WithReplacement => {
274+
let distribution = Uniform::from(0..self.len_of(axis));
275+
(0..n_samples).map(|_| distribution.sample(rng)).collect()
276+
}
277+
SamplingStrategy::WithoutReplacement => {
278+
index::sample(rng, self.len_of(axis), n_samples).into_vec()
279+
}
280+
};
281+
self.select(axis, &indices)
282+
}
283+
}
284+
285+
/// Used as parameter in [`sample_axis`] and [`sample_axis_using`] to determine
286+
/// if lanes from the original array should only be sampled once (*without replacement*) or
287+
/// multiple times (*with replacement*).
288+
///
289+
/// [`sample_axis`]: trait.RandomExt.html#tymethod.sample_axis
290+
/// [`sample_axis_using`]: trait.RandomExt.html#tymethod.sample_axis_using
291+
#[derive(Debug, Clone)]
292+
pub enum SamplingStrategy {
293+
WithReplacement,
294+
WithoutReplacement,
295+
}
296+
297+
// `Arbitrary` enables `quickcheck` to generate random `SamplingStrategy` values for testing.
298+
#[cfg(feature = "quickcheck")]
299+
impl Arbitrary for SamplingStrategy {
300+
fn arbitrary<G: Gen>(g: &mut G) -> Self {
301+
if g.gen_bool(0.5) {
302+
SamplingStrategy::WithReplacement
303+
} else {
304+
SamplingStrategy::WithoutReplacement
305+
}
306+
}
307+
}
308+
309+
fn get_rng() -> SmallRng {
310+
SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed")
144311
}
145312

146313
/// A wrapper type that allows casting f64 distributions to f32

ndarray-rand/tests/tests.rs

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
use ndarray::Array;
1+
use ndarray::{Array, Array2, ArrayView1, Axis};
2+
#[cfg(feature = "quickcheck")]
3+
use ndarray_rand::rand::{distributions::Distribution, thread_rng};
24
use ndarray_rand::rand_distr::Uniform;
3-
use ndarray_rand::RandomExt;
5+
use ndarray_rand::{RandomExt, SamplingStrategy};
6+
use quickcheck::quickcheck;
47

58
#[test]
69
fn test_dim() {
@@ -14,3 +17,94 @@ fn test_dim() {
1417
}
1518
}
1619
}
20+
21+
#[test]
22+
#[should_panic]
23+
fn oversampling_without_replacement_should_panic() {
24+
let m = 5;
25+
let a = Array::random((m, 4), Uniform::new(0., 2.));
26+
let _samples = a.sample_axis(Axis(0), m + 1, SamplingStrategy::WithoutReplacement);
27+
}
28+
29+
quickcheck! {
30+
fn oversampling_with_replacement_is_fine(m: usize, n: usize) -> bool {
31+
let a = Array::random((m, n), Uniform::new(0., 2.));
32+
// Higher than the length of both axes
33+
let n_samples = m + n + 1;
34+
35+
// We don't want to deal with sampling from 0-length axes in this test
36+
if m != 0 {
37+
if !sampling_works(&a, SamplingStrategy::WithReplacement, Axis(0), n_samples) {
38+
return false;
39+
}
40+
}
41+
42+
// We don't want to deal with sampling from 0-length axes in this test
43+
if n != 0 {
44+
if !sampling_works(&a, SamplingStrategy::WithReplacement, Axis(1), n_samples) {
45+
return false;
46+
}
47+
}
48+
49+
true
50+
}
51+
}
52+
53+
#[cfg(feature = "quickcheck")]
54+
quickcheck! {
55+
fn sampling_behaves_as_expected(m: usize, n: usize, strategy: SamplingStrategy) -> bool {
56+
let a = Array::random((m, n), Uniform::new(0., 2.));
57+
let mut rng = &mut thread_rng();
58+
59+
// We don't want to deal with sampling from 0-length axes in this test
60+
if m != 0 {
61+
let n_row_samples = Uniform::from(1..m+1).sample(&mut rng);
62+
if !sampling_works(&a, strategy.clone(), Axis(0), n_row_samples) {
63+
return false;
64+
}
65+
}
66+
67+
// We don't want to deal with sampling from 0-length axes in this test
68+
if n != 0 {
69+
let n_col_samples = Uniform::from(1..n+1).sample(&mut rng);
70+
if !sampling_works(&a, strategy, Axis(1), n_col_samples) {
71+
return false;
72+
}
73+
}
74+
75+
true
76+
}
77+
}
78+
79+
fn sampling_works(
80+
a: &Array2<f64>,
81+
strategy: SamplingStrategy,
82+
axis: Axis,
83+
n_samples: usize,
84+
) -> bool {
85+
let samples = a.sample_axis(axis, n_samples, strategy);
86+
samples
87+
.axis_iter(axis)
88+
.all(|lane| is_subset(&a, &lane, axis))
89+
}
90+
91+
// Check if, when sliced along `axis`, there is at least one lane in `a` equal to `b`
92+
fn is_subset(a: &Array2<f64>, b: &ArrayView1<f64>, axis: Axis) -> bool {
93+
a.axis_iter(axis).any(|lane| &lane == b)
94+
}
95+
96+
#[test]
97+
#[should_panic]
98+
fn sampling_without_replacement_from_a_zero_length_axis_should_panic() {
99+
let n = 5;
100+
let a = Array::random((0, n), Uniform::new(0., 2.));
101+
let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithoutReplacement);
102+
}
103+
104+
#[test]
105+
#[should_panic]
106+
fn sampling_with_replacement_from_a_zero_length_axis_should_panic() {
107+
let n = 5;
108+
let a = Array::random((0, n), Uniform::new(0., 2.));
109+
let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithReplacement);
110+
}

scripts/all-tests.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ cargo test --verbose --no-default-features
1313
cargo test --release --verbose --no-default-features
1414
cargo build --verbose --features "$FEATURES"
1515
cargo test --verbose --features "$FEATURES"
16+
cargo test --manifest-path=ndarray-rand/Cargo.toml --no-default-features --verbose
17+
cargo test --manifest-path=ndarray-rand/Cargo.toml --features quickcheck --verbose
1618
cargo test --manifest-path=serialization-tests/Cargo.toml --verbose
1719
cargo test --manifest-path=blas-tests/Cargo.toml --verbose
1820
CARGO_TARGET_DIR=target/ cargo test --manifest-path=numeric-tests/Cargo.toml --verbose

0 commit comments

Comments
 (0)