8000 Make iterators covariant in element type by bluss · Pull Request #1417 · rust-ndarray/ndarray · GitHub
[go: up one dir, main page]

Skip to content

Make iterators covariant in element type #1417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Make iterators covariant in element type
The internal Baseiter type underlies most of the ndarray iterators, and
it used `*mut A` for element type A. Update it to use `NonNull<A>` which
behaves identically except it's guaranteed to be non-null and is
covariant w.r.t the parameter A.

Add compile test from the issue.

Fixes #1290
  • Loading branch information
bluss committed Aug 6, 2024
commit a9605dc251776985f1e6a49f12f91b61dbe9daa1
4 changes: 2 additions & 2 deletions src/impl_owned_array.rs
8000
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use alloc::vec::Vec;
use std::mem;
use std::mem::MaybeUninit;

#[allow(unused_imports)]
#[allow(unused_imports)] // Needed for Rust 1.64
use rawpointer::PointerExt;

use crate::imp_prelude::*;
Expand Down Expand Up @@ -907,7 +907,7 @@ where D: Dimension

// iter is a raw pointer iterator traversing the array in memory order now with the
// sorted axes.
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review note: it's evident that Baseiter is constructed from a NonNull everywhere, which means that its non-null requirement is easily fulfilled.

let mut dropped_elements = 0;

let mut last_ptr = data_ptr;
Expand Down
8 changes: 4 additions & 4 deletions src/impl_views/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}
}

Expand All @@ -209,7 +209,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}
}

Expand All @@ -220,7 +220,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}

#[inline]
Expand Down Expand Up @@ -262,7 +262,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}

#[inline]
Expand Down
5 changes: 2 additions & 3 deletions src/iterators/into_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ impl<A, D> IntoIter<A, D>
where D: Dimension
{
/// Create a new by-value iterator that consumes `array`
pub(crate) fn new(mut array: Array<A, D>) -> Self
pub(crate) fn new(array: Array<A, D>) -> Self
{
unsafe {
let array_head_ptr = array.ptr;
let ptr = array.as_mut_ptr();
let mut array_data = array.data;
let data_len = array_data.release_all_elements();
debug_assert!(data_len >= array.dim.size());
let has_unreachable_elements = array.dim.size() != data_len;
let inner = Baseiter::new(ptr, array.dim, array.strides);
let inner = Baseiter::new(array_head_ptr, array.dim, array.strides);

IntoIter {
array_data,
Expand Down
23 changes: 14 additions & 9 deletions src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ use alloc::vec::Vec;
use std::iter::FromIterator;
use std::marker::PhantomData;
use std::ptr;
use std::ptr::NonNull;

#[allow(unused_imports)] // Needed for Rust 1.64
use rawpointer::PointerExt;

use crate::Ix1;

Expand All @@ -38,7 +42,7 @@ use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
#[derive(Debug)]
pub struct Baseiter<A, D>
{
ptr: *mut A,
ptr: NonNull<A>,
dim: D,
strides: D,
index: Option<D>,
Expand All @@ -50,7 +54,7 @@ impl<A, D: Dimension> Baseiter<A, D>
/// to be correct to avoid performing an unsafe pointer offset while
/// iterating.
#[inline]
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D>
pub unsafe fn new(ptr: NonNull<A>, len: D, stride: D) -> Baseiter<A, D>
{
Baseiter {
ptr,
Expand All @@ -74,7 +78,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
};
let offset = D::stride_offset(&index, &self.strides);
self.index = self.dim.next_for(index);
unsafe { Some(self.ptr.offset(offset)) }
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
}

fn size_hint(&self) -> (usize, Option<usize>)
Expand All @@ -99,7 +103,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
let mut i = 0;
let i_end = len - elem_index;
while i < i_end {
accum = g(accum, row_ptr.offset(i as isize * stride));
accum = g(accum, row_ptr.offset(i as isize * stride).as_ptr());
i += 1;
}
}
Expand Down Expand Up @@ -140,12 +144,12 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
Some(ix) => ix,
};
self.dim[0] -= 1;
let offset = <_>::stride_offset(&self.dim, &self.strides);
let offset = Ix1::stride_offset(&self.dim, &self.strides);
if index == self.dim {
self.index = None;
}

unsafe { Some(self.ptr.offset(offset)) }
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
}

fn nth_back(&mut self, n: usize) -> Option<*mut A>
Expand All @@ -154,11 +158,11 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
let len = self.dim[0] - index[0];
if n < len {
self.dim[0] -= n + 1;
let offset = <_>::stride_offset(&self.dim, &self.strides);
let offset = Ix1::stride_offset(&self.dim, &self.strides);
if index == self.dim {
self.index = None;
}
unsafe { Some(self.ptr.offset(offset)) }
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
} else {
self.index = None;
None
Expand All @@ -178,7 +182,8 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
accum = g(
accum,
self.ptr
.offset(Ix1::stride_offset(&self.dim, &self.strides)),
.offset(Ix1::stride_offset(&self.dim, &self.strides))
.as_ptr(),
);
}
}
Expand Down
34 changes: 31 additions & 3 deletions tests/iterators.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#![allow(
clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names
)]
#![allow(clippy::deref_addrof, clippy::unreadable_literal)]

use ndarray::prelude::*;
use ndarray::{arr3, indices, s, Slice, Zip};
Expand Down Expand Up @@ -1055,3 +1053,33 @@ impl Drop for DropCount<'_>
self.drops.set(self.drops.get() + 1);
}
}

#[test]
fn test_impl_iter_compiles()
{
// Requires that the iterators are covariant in the element type

// base case: std
fn slice_iter_non_empty_indices<'s, 'a>(array: &'a Vec<&'s str>) -> impl Iterator<Item = usize> + 'a
{
array
.iter()
.enumerate()
.filter(|(_index, elem)| !elem.is_empty())
.map(|(index, _elem)| index)
}

let _ = slice_iter_non_empty_indices;

// ndarray case
fn array_iter_non_empty_indices<'s, 'a>(array: &'a Array<&'s str, Ix1>) -> impl Iterator<Item = usize> + 'a
{
array
.iter()
.enumerate()
.filter(|(_index, elem)| !elem.is_empty())
.map(|(index, _elem)| index)
}

let _ = array_iter_non_empty_indices;
}
0