E5F1 Make stack and concatenate compliant with numpy naming. by andrei-papou · Pull Request #850 · rust-ndarray/ndarray · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane
pub use crate::arraytraits::AsArray;
pub use crate::linalg_traits::{LinalgScalar, NdFloat};

#[allow(deprecated)]
pub use crate::stacking::{concatenate, stack, stack_new_axis};

pub use crate::impl_views::IndexLonger;
Expand Down
93 changes: 47 additions & 46 deletions src/stacking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,42 @@
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;

/// Stack arrays along the new axis.
///
/// ***Errors*** if the arrays have mismatching shapes.
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack(Axis(0), &[a.view(), a.view()])
/// == Ok(arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]]))
/// );
/// # }
/// ```
pub fn stack<A, D>(
axis: Axis,
arrays: &[ArrayView<A, D>],
) -> Result<Array<A, D::Larger>, ShapeError>
where
A: Copy,
D: Dimension,
D::Larger: RemoveAxis,
{
stack_new_axis(axis, arrays)
}

/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
Expand All @@ -17,23 +53,19 @@ use crate::imp_prelude::*;
/// if the result is larger than is possible to represent.
///
/// ```
/// use ndarray::{arr2, Axis, stack};
/// use ndarray::{arr2, Axis, concatenate};
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack(Axis(0), &[a.view(), a.view()])
/// concatenate(Axis(0), &[a.view(), a.view()])
/// == Ok(arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]]))
/// );
/// ```
#[deprecated(
since = "0.13.2",
note = "Please use the `concatenate` function instead"
)]
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
Expand Down Expand Up @@ -77,35 +109,6 @@ where
Ok(res)
}

/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
/// (may be made more flexible in the future).<br>
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// use ndarray::{arr2, Axis, concatenate};
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// concatenate(Axis(0), &[a.view(), a.view()])
/// == Ok(arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]]))
/// );
/// ```
#[allow(deprecated)]
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
{
stack(axis, arrays)
}

/// Stack arrays along the new axis.
///
/// ***Errors*** if the arrays have mismatching shapes.
Expand Down Expand Up @@ -173,7 +176,7 @@ where
Ok(res)
}

/// Concatenate arrays along the given axis.
/// Stack arrays along the new axis.
///
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
Expand All @@ -183,25 +186,23 @@ where
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// use ndarray::{arr2, stack, Axis};
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack![Axis(0), a, a]
/// == arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]])
/// == arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]])
/// );
/// # }
/// ```
#[deprecated(
since = "0.13.2",
note = "Please use the `concatenate!` macro instead"
)]
#[macro_export]
macro_rules! stack {
($axis:expr, $( $array:expr ),+ ) => {
Expand Down
35 changes: 7 additions & 28 deletions tests/stacking.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,7 @@
#![allow(deprecated)]

use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1};

#[test]
fn concatenating() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));

let c = stack![Axis(0), a, b];
assert_eq!(
c,
arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])
);

let d = stack![Axis(0), a.row(0), &[9., 9.]];
assert_eq!(d, aview1(&[2., 2., 9., 9.]));

let res = ndarray::stack(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::stack(Axis(2), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::stack(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);

let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
Expand All @@ -52,16 +28,19 @@ fn concatenating() {
#[test]
fn stacking() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack_new_axis(Axis(0), &[a.view(), a.view()]).unwrap();
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));

let c = stack![Axis(0), a, a];
assert_eq!(c, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));

let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]);
let res = ndarray::stack_new_axis(Axis(1), &[a.view(), c.view()]);
let res = ndarray::stack(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::stack_new_axis(Axis(3), &[a.view(), a.view()]);
let res = ndarray::stack(Axis(3), &[a.view(), a.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), &[]);
let res: Result<Array2<f64>, _> = ndarray::stack::<_, Ix1>(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}
0