Skip to content

Commit a6fe82f

Browse files
authored
Merge pull request #916 from dam5h/permute-axis-not-sorting
Fix sort-axis example, add test to confirm original error and resolution
2 parents 07853e8 + c27626a commit a6fe82f

File tree

2 files changed

+72
-9
lines changed

2 files changed

+72
-9
lines changed

examples/sort-axis.rs

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,21 +102,27 @@ where
102102
assert_eq!(axis_len, perm.indices.len());
103103
debug_assert!(perm.correct());
104104

105+
if self.is_empty() {
106+
return self;
107+
}
108+
105109
let mut result = Array::uninit(self.dim());
106110

107111
unsafe {
108112
// logically move ownership of all elements from self into result
109113
// the result realizes this ownership at .assume_init() further down
110114
let mut moved_elements = 0;
111-
for i in 0..axis_len {
112-
let perm_i = perm.indices[i];
113-
Zip::from(result.index_axis_mut(axis, perm_i))
114-
.and(self.index_axis(axis, i))
115-
.for_each(|to, from| {
116-
copy_nonoverlapping(from, to.as_mut_ptr(), 1);
117-
moved_elements += 1;
118-
});
119-
}
115+
Zip::from(&perm.indices)
116+
.and(result.axis_iter_mut(axis))
117+
.for_each(|&perm_i, result_pane| {
118+
// possible improvement: use unchecked indexing for `index_axis`
119+
Zip::from(result_pane)
120+
.and(self.index_axis(axis, perm_i))
121+
.for_each(|to, from| {
122+
copy_nonoverlapping(from, to.as_mut_ptr(), 1);
123+
moved_elements += 1;
124+
});
125+
});
120126
debug_assert_eq!(result.len(), moved_elements);
121127
// panic-critical begin: we must not panic
122128
// forget moved array elements but not its vec
@@ -129,6 +135,7 @@ where
129135
}
130136
}
131137
}
138+
132139
#[cfg(feature = "std")]
133140
fn main() {
134141
let a = Array::linspace(0., 63., 64).into_shape((8, 8)).unwrap();
@@ -143,5 +150,60 @@ fn main() {
143150
let c = strings.permute_axis(Axis(1), &perm);
144151
println!("{:?}", c);
145152
}
153+
146154
#[cfg(not(feature = "std"))]
147155
fn main() {}
156+
157+
#[cfg(test)]
158+
mod tests {
159+
use super::*;
160+
#[test]
161+
fn test_permute_axis() {
162+
let a = array![
163+
[107998.96, 1.],
164+
[107999.08, 2.],
165+
[107999.20, 3.],
166+
[108000.33, 4.],
167+
[107999.45, 5.],
168+
[107999.57, 6.],
169+
[108010.69, 7.],
170+
[107999.81, 8.],
171+
[107999.94, 9.],
172+
[75600.09, 10.],
173+
[75600.21, 11.],
174+
[75601.33, 12.],
175+
[75600.45, 13.],
176+
[75600.58, 14.],
177+
[109000.70, 15.],
178+
[75600.82, 16.],
179+
[75600.94, 17.],
180+
[75601.06, 18.],
181+
];
182+
183+
let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] < a[[j, 0]]);
184+
let b = a.permute_axis(Axis(0), &perm);
185+
assert_eq!(
186+
b,
187+
array![
188+
[75600.09, 10.],
189+
[75600.21, 11.],
190+
[75600.45, 13.],
191+
[75600.58, 14.],
192+
[75600.82, 16.],
193+
[75600.94, 17.],
194+
[75601.06, 18.],
195+
[75601.33, 12.],
196+
[107998.96, 1.],
197+
[107999.08, 2.],
198+
[107999.20, 3.],
199+
[107999.45, 5.],
200+
[107999.57, 6.],
201+
[107999.81, 8.],
202+
[107999.94, 9.],
203+
[108000.33, 4.],
204+
[108010.69, 7.],
205+
[109000.70, 15.],
206+
]
207+
);
208+
}
209+
}

scripts/all-tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ cargo test --manifest-path=ndarray-rand/Cargo.toml --no-default-features --verbo
1616
cargo test --manifest-path=ndarray-rand/Cargo.toml --features quickcheck --verbose
1717
cargo test --manifest-path=serialization-tests/Cargo.toml --verbose
1818
cargo test --manifest-path=blas-tests/Cargo.toml --verbose
19+
cargo test --examples
1920
CARGO_TARGET_DIR=target/ cargo test --manifest-path=numeric-tests/Cargo.toml --verbose
2021
([ "$CHANNEL" != "nightly" ] || cargo bench --no-run --verbose --features "$FEATURES")

0 commit comments

Comments
 (0)