diff --git a/src/impl_expand.rs b/src/impl_expand.rs new file mode 100644 index 000000000..4a6cc328e --- /dev/null +++ b/src/impl_expand.rs @@ -0,0 +1,233 @@ +// Copyright 2021 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::imp_prelude::*; + +use crate::data_traits::RawDataSubst; +use crate::math_cell::MathCell; + +use std::cell::Cell; +use std::mem; +use std::num::Wrapping; +use std::num; +use std::ptr::NonNull; + +use num_complex::Complex; + + +pub unsafe trait EqualRepresentation { } + +unsafe impl EqualRepresentation for T { } + +macro_rules! unsafe_impl_trivial_transmute_t { + ($($from:ty => $to:ty,)+) => { + $( +unsafe impl EqualRepresentation<$to> for $from { } + )+ + + } +} + +macro_rules! unsafe_impl_trivial_transmute_bidir { + ($(($from:ty) <=> $to:ty,)+) => { + $( +unsafe impl EqualRepresentation<$to> for $from { } +unsafe impl EqualRepresentation<$from> for $to { } + )+ + + } +} + +macro_rules! unsafe_impl_trivial_transmute { + ($($from:ty => $to:ty,)+) => { + $( +unsafe impl EqualRepresentation<$to> for $from { } + )+ + + } +} + +unsafe_impl_trivial_transmute_t! { + T => [T; 1], + [T; 1] => T, + // transmute from cell to T, but not the other way around + Cell => T, + MathCell => T, + *mut T => *const T, + *const T => *mut T, + NonNull => *mut T, + NonNull => *const T, + Wrapping => T, + T => Wrapping, +} + +unsafe_impl_trivial_transmute_bidir! { + (usize) <=> isize, + (u128) <=> i128, + (u64) <=> i64, + (u32) <=> i32, + (u16) <=> i16, + (u8) <=> i8, +} + +unsafe_impl_trivial_transmute! { + // only from nonzero, not to + num::NonZeroU8 => u8, + num::NonZeroI8 => i8, + num::NonZeroU16 => u16, + num::NonZeroI16 => i16, + num::NonZeroU32 => u32, + num::NonZeroI32 => i32, + num::NonZeroU64 => u64, + num::NonZeroI64 => i64, + num::NonZeroU128 => u128, + num::NonZeroI128 => i128, + num::NonZeroUsize => usize, + num::NonZeroIsize => isize, +} + + +pub unsafe trait MultiElement : MultiElementExtended<::Elem> { + type Elem; +} + +pub unsafe trait MultiElementExtended { + const LEN: usize; +} + +unsafe impl MultiElement for [A; N] { + type Elem = A; +} + +unsafe impl MultiElementExtended for [A; N] { + const LEN: usize = N; +} + +unsafe impl MultiElement for Complex { + type Elem = A; +} + +unsafe impl MultiElementExtended for Complex { + const LEN: usize = 2; +} + +macro_rules! multi_elem { + ($($from:ty => $to:ty,)+) => { + $( +unsafe impl MultiElementExtended<$to> for $from { + const LEN: usize = mem::size_of::<$from>() / mem::size_of::<$to>(); +} + )+ + + } +} + +multi_elem! { + usize => u8, + isize => i8, + u128 => i64, + u128 => u64, + u128 => i32, + u128 => u32, + u128 => i16, + u128 => u16, + u128 => i8, + u128 => u8, + u64 => i32, + u64 => u32, + u64 => i16, + u64 => u16, + u64 => i8, + u64 => u8, + u32 => i16, + u32 => u16, + u32 => i8, + u32 => u8, + u16 => i8, + u16 => u8, + i128 => i64, + i128 => u64, + i128 => i32, + i128 => u32, + i128 => i16, + i128 => u16, + i128 => i8, + i128 => u8, + i64 => i32, + i64 => u32, + i64 => i16, + i64 => u16, + i64 => i8, + i64 => u8, + i32 => i16, + i32 => u16, + i32 => i8, + i32 => u8, + i16 => i8, + i16 => u8, +} + +impl<'a, A, D> ArrayView<'a, A, D> +where + D: Dimension, + A: MultiElement, +{ + /// + /// Note: expanding a zero-element array, `[A; 0]`, leads to a new axis of length zero, + /// i.e. the result is an empty array view. + /// + /// **Panics** if the product of non-zero axis lengths overflows `isize`. + pub fn expand(self, new_axis: Axis) -> ArrayView<'a, A::Elem, D::Larger> { + self.expand_to::(new_axis) + } +} + +impl<'a, A, D> ArrayView<'a, A, D> +where + D: Dimension, +{ + /// + /// Note: expanding a zero-element array, `[A; 0]`, leads to a new axis of length zero, + /// i.e. the result is an empty array view. + /// + /// **Panics** if the product of non-zero axis lengths overflows `isize`. + pub fn expand_to(self, new_axis: Axis) -> ArrayView<'a, T, D::Larger> + where A: MultiElementExtended, + { + assert_eq!(mem::size_of::(), mem::size_of::() * A::LEN); + + let mut dim = self.dim.insert_axis(new_axis); + let mut strides = self.strides.insert_axis(new_axis); + + // Double the strides. In the zero-sized element case and for axes of + // length <= 1, we leave the strides as-is to avoid possible overflow. + let len = A::LEN as isize; + if mem::size_of::() != 0 { + for ax in 0..strides.ndim() { + if Axis(ax) == new_axis { + continue; + } + if dim[ax] > 1 { + strides[ax] = ((strides[ax] as isize) * len) as usize; + } + } + } + dim[new_axis.index()] = A::LEN; + + // TODO nicer assertion + crate::dimension::size_of_shape_checked(&dim).unwrap(); + + // safe because + // size still fits in isize; + // new strides are adapted to new element type, inside the same allocation. + unsafe { + ArrayBase::from_data_ptr(self.data.data_subst(), self.ptr.cast()) + .with_strides_dim(strides, dim) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index f773226f5..701caf94e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,9 @@ -// Copyright 2014-2020 bluss and ndarray developers. +// Copyright 2014-2021 ndarray developers. +// Main authors: +// +// Ulrik Sverdrup "bluss" +// Jim Turner "jturner314" +// and many others // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -1515,6 +1520,7 @@ mod impl_internal_constructors; mod impl_constructors; mod impl_methods; +mod impl_expand; mod impl_owned_array; mod impl_special_element_types; diff --git a/tests/expand.rs b/tests/expand.rs new file mode 100644 index 000000000..2a83a4739 --- /dev/null +++ b/tests/expand.rs @@ -0,0 +1,98 @@ + +use ndarray::prelude::*; + +use ndarray::stack; + +use num_complex::Complex; + +fn cx(re: T, im: T) -> Complex { + Complex::new(re, im) +} + +#[test] +fn test_expand_from_zero() { + let a = Array::from_elem((), [[[1, 2], [3, 4], [5, 6]], + [[11, 12], [13, 14], [15, 16]]]); + let av = a.view(); + println!("{:?}", av); + let av = av.expand(Axis(0)); + println!("{:?}", av); + let av = av.expand(Axis(1)); + println!("{:?}", av); + let av = av.expand(Axis(2)); + println!("{:?}", av); + assert!(av.is_standard_layout()); + assert_eq!(av, av.to_owned()); + + let av = a.view(); + println!("{:?}", av); + let av = av.expand(Axis(0)); + println!("{:?}", av); + let av = av.expand(Axis(0)); + println!("{:?}", av); + let av = av.expand(Axis(0)); + println!("{:?}", av); + assert!(av.t().is_standard_layout()); + assert_eq!(av, av.to_owned()); +} + +#[test] +fn test_expand_zero() { + let a = Array::from_elem((3, 4), [0.; 0]); + + for ax in 0..=2 { + let mut new_shape = [3, 4, 4]; + new_shape[1] = if ax == 0 { 3 } else { 4 }; + new_shape[ax] = 0; + let av = a.view(); + let av = av.expand(Axis(ax)); + assert_eq!(av.shape(), &new_shape); + } +} + +#[test] +fn test_expand1() { + let a = Array::from_elem((3, 3), [1, 2, 3]); + println!("{:?}", a); + let b = a.view().expand(Axis(2)); + println!("{:?}", b); + let b = a.view().expand(Axis(1)); + println!("{:?}", b); + let b = a.view().expand(Axis(0)); + println!("{:?}", b); +} + + +#[test] +fn test_complex() { + let a = arr2(&[[cx(3., 4.), cx(2., 0.)], [cx(0., -2.), cx(3., 0.)]]); + let av = a.view(); + for ax in 0..=2 { + let av = av.expand(Axis(ax)); + let answer = stack![Axis(ax), a.mapv(|z| z.re), a.mapv(|z| z.im)]; + assert_eq!(av, answer); + } +} + + +#[test] +fn test_u128() { + let a = array![[1, 2, 3u128], [4, 5, 6]]; + let av = a.view(); + println!("{:?}", av); + for ax in 0..=2 { + let av = av.expand_to::(Axis(ax)); + println!("{:?}", av); + } +} + +#[test] +fn test_u64() { + let a = array![[1, 2, 3u64], [4, 5, 6]]; + let av = a.view(); + println!("{:?}", av); + for ax in 0..=2 { + let av = av.expand_to::(Axis(ax)); + println!("{:?}", av); + } +}