8000 Add support for inserting new axes while slicing by jturner314 · Pull Request #570 · rust-ndarray/ndarray · GitHub
[go: up one dir, main page]

Skip to content

Add support for inserting new axes while slicing #570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Mar 13, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
24a3299
Rename SliceOrIndex to AxisSliceInfo
jturner314 Dec 9, 2018
6a16b88
Switch from Dimension::SliceArg to CanSlice trait
jturner314 Dec 9, 2018
546b69c
Add support for inserting new axes while slicing
jturner314 Dec 9, 2018
6e335ca
Rename SliceInfo generic params to Din and Dout
jturner314 Dec 17, 2018
d6b9cb0
Improve code style
jturner314 Dec 17, 2018
438d69a
Derive Clone, Copy, and Debug for NewAxis
jturner314 Dec 17, 2018
6050df3
Use stringify! for string literal of type name
jturner314 Dec 18, 2018
8d45268
Make step_by panic for variants other than Slice
jturner314 Dec 18, 2018
1d15275
Add DimAdd trait
jturner314 Dec 18, 2018
41cc4a1
Replace SliceNextIn/OutDim with SliceArg trait
jturner314 Dec 18, 2018
c66ad8c
Combine DimAdd impls for Ix0
jturner314 Feb 7, 2021
7776bfc
Implement CanSlice<IxDyn> for [AxisSliceInfo]
jturner314 Feb 14, 2021
ab79d28
Change SliceInfo to be repr(transparent)
jturner314 Feb 15, 2021
615113e
Add debug assertions to SliceInfo::new_unchecked
jturner314 Feb 15, 2021
e66e3c8
Fix safety of SliceInfo::new
jturner314 Feb 15, 2021
3ba6ceb
Add some impls of TryFrom for SliceInfo
jturner314 Feb 15, 2021
815e708
Make slice_move not call slice_collapse
jturner314 Feb 16, 2021
25a7bb0
Make slice_collapse return Err(_) for NewAxis
jturner314 Feb 16, 2021
5202a50
Expose CanSlice trait in public API
jturner314 Feb 16, 2021
319701d
Expose MultiSlice trait in public API
jturner314 Feb 16, 2021
d5d6482
Add DimAdd bounds to Dimension trait
jturner314 Feb 16, 2021
9614b13
Revert "Make slice_collapse return Err(_) for NewAxis"
jturner314 Feb 17, 2021
61cf7c0
Make slice_collapse panic on NewAxis
jturner314 Feb 17, 2021
91dbf3f
Rename DimAdd::Out to DimAdd::Output
jturner314 Feb 17, 2021
5dc77bd
Rename SliceArg to SliceNextDim
jturner314 Feb 17, 2021
87515c6
Rename CanSlice to SliceArg
jturner314 Feb 17, 2021
c4efbbf
Rename MultiSlice to MultiSliceArg
jturner314 Feb 17, 2021
7506f90
Clarify docs of .slice_collapse()
jturner314 Feb 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add some impls of TryFrom for SliceInfo
  • Loading branch information
jturner314 authored and bluss committed Mar 12, 2021
commit 3ba6ceb8cd0463cb5164e19b295a26b376cd5261
114 changes: 98 additions & 16 deletions src/slice.rs
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
use crate::dimension::slices_intersect;
use crate::error::{ErrorKind, ShapeError};
use crate::{ArrayViewMut, DimAdd, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
use alloc::vec::Vec;
use std::convert::TryFrom;
use std::fmt;
use std::marker::PhantomData;
use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
Expand Down Expand Up @@ -402,6 +404,24 @@ where
}
}

fn check_dims_for_sliceinfo<Din, Dout>(indices: &[AxisSliceInfo]) -> Result<(), ShapeError>
where
Din: Dimension,
Dout: Dimension,
{
if let Some(in_ndim) = Din::NDIM {
if in_ndim != indices.in_ndim() {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
}
}
if let Some(out_ndim) = Dout::NDIM {
if out_ndim != indices.out_ndim() {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
}
Ok(())
}

impl<T, Din, Dout> SliceInfo<T, Din, Dout>
where
T: AsRef<[AxisSliceInfo]>,
Copy link
Member
@bluss bluss Feb 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This raises my AsRef concern - an unfaithful/evil impl of AsRef can return different slices at different times, we can't trust AsRef for unsafe code blocks - I think it means the check we do in SliceInfo::new is not effectively guarding the invariant we need because .as_ref() can later return a different slice. 🙁 This is one of my pet peeves - can't trust AsRef for unsafe code, unfortunately.

I wonder if and I hope that we haven't talked about this before? Maybe there's a reason this is not a concern?

Copy link
Member Author
@jturner314 jturner314 Feb 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, good point. I've changed SliceInfo::new to be unsafe and added TryFrom implementations for a whitelisted set of types (slices, Vec, and fixed-length arrays) which we should be able to depend on having correct AsRef implementations.

It would be nice to make CanSlice and the SliceInfo constructors safe, i.e. move the responsibility for the necessary checks to the slicing methods. Modifying the .slice*() methods would be fairly simple. (The primary change would be to make sure that they call .as_ref() only once.) However, modifying the implementation of multi-slicing would be trickier.

Expand All @@ -424,12 +444,8 @@ where
out_dim: PhantomData<Dout>,
) -> SliceInfo<T, Din, Dout> {
if cfg!(debug_assertions) {
if let Some(in_ndim) = Din::NDIM {
assert_eq!(in_ndim, indices.as_ref().in_ndim());
}
if let Some(out_ndim) = Dout::NDIM {
assert_eq!(out_ndim, indices.as_ref().out_ndim());
}
check_dims_for_sliceinfo::<Din, Dout>(indices.as_ref())
.expect("`Din` and `Dout` must be consistent with `indices`.");
}
SliceInfo {
in_dim,
Expand All @@ -449,21 +465,14 @@ where
///
/// Errors if `Din` or `Dout` is not consistent with `indices`.
///
/// For common types, a safe alternative is to use `TryFrom` instead.
///
/// # Safety
///
/// The caller must ensure `indices.as_ref()` always returns the same value
/// when called multiple times.
pub unsafe fn new(indices: T) -> Result<SliceInfo<T, Din, Dout>, ShapeError> {
if let Some(in_ndim) = Din::NDIM {
if in_ndim != indices.as_ref().in_ndim() {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
}
}
if let Some(out_ndim) = Dout::NDIM {
if out_ndim != indices.as_ref().out_ndim() {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
}
}
check_dims_for_sliceinfo::<Din, Dout>(indices.as_ref())?;
Ok(SliceInfo {
in_dim: PhantomData,
out_dim: PhantomData,
Expand Down Expand Up @@ -508,6 +517,79 @@ where
}
}

impl<'a, Din, Dout> TryFrom<&'a [AxisSliceInfo]> for &'a SliceInfo<[AxisSliceInfo], Din, Dout>
where
Din: Dimension,
Dout: Dimension,
{
type Error = ShapeError;

fn try_from(
indices: &'a [AxisSliceInfo],
) -> Result<&'a SliceInfo<[AxisSliceInfo], Din, Dout>, ShapeError> {
check_dims_for_sliceinfo::<Din, Dout>(indices)?;
unsafe {
// This is okay because we've already checked the correctness of
// `Din` and `Dout`, and the only non-zero-sized member of
// `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Din,
// Dout>` should have the same bitwise representation as
// `&[AxisSliceInfo]`.
Ok(&*(indices as *const [AxisSliceInfo]
as *const SliceInfo<[AxisSliceInfo], Din, Dout>))
}
}
}

impl<Din, Dout> TryFrom<Vec<AxisSliceInfo>> for SliceInfo<Vec<AxisSliceInfo>, Din, Dout>
where
Din: Dimension,
Dout: Dimension,
{
type Error = ShapeError;

fn try_from(
indices: Vec<AxisSliceInfo>,
) -> Result<SliceInfo<Vec<AxisSliceInfo>, Din, Dout>, ShapeError> {
unsafe {
// This is okay because `Vec` always returns the same value for
// `.as_ref()`.
Self::new(indices)
}
}
}

macro_rules! impl_tryfrom_array_for_sliceinfo {
($len:expr) => {
impl<Din, Dout> TryFrom<[AxisSliceInfo; $len]>
for SliceInfo<[AxisSliceInfo; $len], Din, Dout>
where
Din: Dimension,
Dout: Dimension,
{
type Error = ShapeError;

fn try_from(
indices: [AxisSliceInfo; $len],
) -> Result<SliceInfo<[AxisSliceInfo; $len], Din, Dout>, ShapeError> {
unsafe {
// This is okay because `[AxisSliceInfo; N]` always returns
// the same value for `.as_ref()`.
Self::new(indices)
}
}
}
};
}
impl_tryfrom_array_for_sliceinfo!(0);
impl_tryfrom_array_for_sliceinfo!(1);
impl_tryfrom_array_for_sliceinfo!(2);
impl_tryfrom_array_for_sliceinfo!(3);
impl_tryfrom_array_for_sliceinfo!(4);
impl_tryfrom_array_for_sliceinfo!(5);
impl_tryfrom_array_for_sliceinfo!(6);
impl_tryfrom_array_for_sliceinfo!(7);
impl_tryfrom_array_for_sliceinfo!(8);

impl<T, Din, Dout> AsRef<[AxisSliceInfo]> for SliceInfo<T, Din, Dout>
where
T: AsRef<[AxisSliceInfo]>,
Expand Down
65 changes: 29 additions & 36 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use ndarray::prelude::*;
use ndarray::{arr3, rcarr2};
use ndarray::indices;
use ndarray::{AxisSliceInfo, Slice, SliceInfo};
use std::convert::TryFrom;

macro_rules! assert_panics {
($body:expr) => {
Expand Down Expand Up @@ -216,15 +217,13 @@ fn test_slice_dyninput_array_fixed() {
#[test]
fn test_slice_array_dyn() {
let mut arr = Array3::<f64>::zeros((5, 2, 5));
let info = &unsafe {
SliceInfo::<_, Ix3, IxDyn>::new([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
AxisSliceInfo::from(..).step_by(2),
])
.unwrap()
};
let info = &SliceInfo::<_, Ix3, IxDyn>::try_from([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
AxisSliceInfo::from(..).step_by(2),
])
.unwrap();
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
Expand All @@ -234,15 +233,13 @@ fn test_slice_array_dyn() {
#[test]
fn test_slice_dyninput_array_dyn() {
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
let info = &unsafe {
SliceInfo::<_, Ix3, IxDyn>::new([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
AxisSliceInfo::from(..).step_by(2),
])
.unwrap()
};
let info = &SliceInfo::<_, Ix3, IxDyn>::try_from([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
AxisSliceInfo::from(..).step_by(2),
])
.unwrap();
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
Expand All @@ -252,15 +249,13 @@ fn test_slice_dyninput_array_dyn() {
#[test]
fn test_slice_dyninput_vec_fixed() {
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
let info = &unsafe {
SliceInfo::<_, Ix3, Ix3>::new(vec![
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
AxisSliceInfo::from(..).step_by(2),
])
.unwrap()
};
let info = &SliceInfo::<_, Ix3, Ix3>::try_from(vec![
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
AxisSliceInfo::from(..).step_by(2),
])
.unwrap();
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
Expand All @@ -270,15 +265,13 @@ fn test_slice_dyninput_vec_fixed() {
#[test]
fn test_slice_dyninput_vec_dyn() {
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
let info = &unsafe {
SliceInfo::<_, Ix3, IxDyn>::new(vec![
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
AxisSliceInfo::from(..).step_by(2),
])
.unwrap()
};
let info = &SliceInfo::<_, Ix3, IxDyn>::try_from(vec![
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
AxisSliceInfo::from(..).step_by(2),
])
.unwrap();
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
Expand Down
3 changes: 2 additions & 1 deletion tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ fn scaled_add_2() {
fn scaled_add_3() {
use approx::assert_relative_eq;
use ndarray::{SliceInfo, AxisSliceInfo};
use std::convert::TryFrom;

let beta = -2.3;
let sizes = vec![
Expand Down Expand Up @@ -595,7 +596,7 @@ fn scaled_add_3() {

{
let mut av = a.slice_mut(s![..;s1, ..;s2]);
let c = c.slice(&unsafe { SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap() });
let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::try_from(cslice).unwrap());

let mut answerv = answer.slice_mut(s![..;s1, ..;s2]);
answerv += &(beta * &c);
Expand Down
0