8000 Merge pull request #814 from rust-ndarray/zip-collect-drop · rust-ndarray/ndarray@adef586 · GitHub
[go: up one dir, main page]

Skip to content

Commit adef586

Browse files
authored
Merge pull request #814 from rust-ndarray/zip-collect-drop
Implement Zip::apply_collect for non-Copy elements too
2 parents 3ea6861 + 624fd75 commit adef586

File tree

4 files changed

+267
-7
lines changed

4 files changed

+267
-7
lines changed

benches/bench1.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,38 @@ fn add_2d_alloc_zip_collect(bench: &mut test::Bencher) {
286286
});
287287
}
288288

289+
#[bench]
290+
fn vec_string_collect(bench: &mut test::Bencher) {
291+
let v = vec![""; 10240];
292+
bench.iter(|| {
293+
v.iter().map(|s| s.to_owned()).collect::<Vec<_>>()
294+
});
295+
}
296+
297+
#[bench]
298+
fn array_string_collect(bench: &mut test::Bencher) {
299+
let v = Array::from(vec![""; 10240]);
300+
bench.iter(|| {
301+
Zip::from(&v).apply_collect(|s| s.to_owned())
302+
});
303+
}
304+
305+
#[bench]
306+
fn vec_f64_collect(bench: &mut test::Bencher) {
307+
let v = vec![1.; 10240];
308+
bench.iter(|| {
309+
v.iter().map(|s| s + 1.).collect::<Vec<_>>()
310+
});
311+
}
312+
313+
#[bench]
314+
fn array_f64_collect(bench: &mut test::Bencher) {
315+
let v = Array::from(vec![1.; 10240]);
316+
bench.iter(|| {
317+
Zip::from(&v).apply_collect(|s| s + 1.)
318+
});
319+
}
320+
289321

290322
#[bench]
291323
fn add_2d_assign_ops(bench: &mut test::Bencher) {

src/zip/mod.rs

Lines changed: 19 additions & 7 deletions
< 10000 td data-grid-cell-id="diff-b85e1d27bc138ae483f91675ff4ba8b3a6bb6c634d0a6e939e37f69f16fd45b3-1082-1093-0" data-selected="false" role="gridcell" style="background-color:var(--diffBlob-additionNum-bgColor, var(--diffBlob-addition-bgColor-num));text-align:center" tabindex="-1" valign="top" class="focusable-grid-cell diff-line-number position-relative left-side">
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#[macro_use]
1010
mod zipmacro;
11+
mod partial_array;
1112

1213
use std::mem::MaybeUninit;
1314

@@ -20,6 +21,8 @@ use crate::NdIndex;
2021
use crate::indexes::{indices, Indices};
2122
use crate::layout::{CORDER, FORDER};
2223

24+
use partial_array::PartialArray;
25+
2326
/// Return if the expression is a break value.
2427
macro_rules! fold_while {
2528
($e:expr) => {
@@ -195,6 +198,7 @@ pub trait NdProducer {
195198
fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
196199
where
197200
Self: Sized;
201+
198202
private_decl! {}
199203
}
200204

@@ -1070,16 +1074,24 @@ macro_rules! map_impl {
10701074
/// inputs.
10711075
///
10721076
/// If all inputs are c- or f-order respectively, that is preserved in the output.
1073-
///
1074-
/// Restricted to functions that produce copyable results for technical reasons; other
1075-
/// cases are not yet implemented.
10761077
pub fn apply_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D>
1077-
where R: Copy,
10781078
{
1079-
// To support non-Copy elements, implementation of dropping partial array (on
1080-
// panic) is needed
1079+
// Make uninit result
10811080
let mut output = self.uninitalized_for_current_layout::<R>();
1082-
self.apply_assign_into(&mut output, f);
1081+
if !std::mem::needs_drop::<R>() {
1082+
// For elements with no drop glue, just overwrite into the array
1083+
self.apply_assign_into(&mut output, f);
1084+
} else {
1085+
// For generic elements, use a proxy that counts the number of filled elements,
1086+
// and can drop the right number of elements on unwinding
1087+
unsafe {
1088+
PartialArray::scope(output.view_mut(), move |partial| {
1089+
debug_assert_eq!(partial.layout().tendency() >= 0, self.layout_tendency >= 0);
1090+
self.apply_assign_into(partial, f);
1091+
});
1092+
}
1093+
}
1094+
10831095
unsafe {
10841096
output.assume_init()
10851097
}

src/zip/partial_array.rs

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Copyright 2020 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use crate::imp_prelude::*;
10+
use crate::{
11+
AssignElem,
12+
Layout,
13+
NdProducer,
14+
Zip,
15+
FoldWhile,
16+
};
17+
18+
use std::cell::Cell;
19+
use std::mem;
20+
use std::mem::MaybeUninit;
21+
use std::ptr;
22+
23+
/// An assignable element reference that increments a counter when assigned
24+
pub(crate) struct ProxyElem<'a, 'b, A> {
25+
item: &'a mut MaybeUninit<A>,
26+
filled: &'b Cell<usize>
27+
}
28+
29+
impl<'a, 'b, A> AssignElem<A> for ProxyElem<'a, 'b, A> {
30+
fn assign_elem(self, item: A) {
31+
self.filled.set(self.filled.get() + 1);
32+
*self.item = MaybeUninit::new(item);
33+
}
34+
}
35+
36+
/// Handles progress of assigning to a part of an array, for elements that need
37+
/// to be dropped on unwinding. See Self::scope.
38+
pub(crate) struct PartialArray<'a, 'b, A, D>
39+
where D: Dimension
40+
{
41+
data: ArrayViewMut<'a, MaybeUninit<A>, D>,
42+
filled: &'b Cell<usize>,
43+
}
44+
45+
impl<'a, 'b, A, D> PartialArray<'a, 'b, A, D>
46+
where D: Dimension
47+
{
48+
/// Create a temporary PartialArray that wraps the array view `data`;
49+
/// if the end of the scope is reached, the partial array is marked complete;
50+
/// if execution unwinds at any time before them, the elements written until then
51+
/// are dropped.
52+
///
53+
/// Safety: the caller *must* ensure that elements will be written in `data`'s preferred order.
54+
/// PartialArray can not handle arbitrary writes, only in the memory order.
55+
pub(crate) unsafe fn scope(data: ArrayViewMut<'a, MaybeUninit<A>, D>,
56+
scope_fn: impl FnOnce(&mut PartialArray<A, D>))
57+
{
58+
let filled = Cell::new(0);
59+
let mut partial = PartialArray::new(data, &filled);
60+
scope_fn(&mut partial);
61+
filled.set(0); // mark complete
62+
}
63+
64+
unsafe fn new(data: ArrayViewMut<'a, MaybeUninit<A>, D>,
65+
filled: &'b Cell<usize>) -> Self
66+
{
67+
debug_assert_eq!(filled.get(), 0);
68+
Self { data, filled }
69+
}
70+
}
71+
72+
impl<'a, 'b, A, D> Drop for PartialArray<'a, 'b, A, D>
73+
where D: Dimension
74+
{
75+
fn drop(&mut self) {
76+
if !mem::needs_drop::<A>() {
77+
return;
78+
}
79+
80+
let mut count = self.filled.get();
81+
if count == 0 {
82+
return;
83+
}
84+
85+
Zip::from(self).fold_while((), move |(), elt| {
86+
if count > 0 {
87+
count -= 1;
88+
unsafe {
89+
ptr::drop_in_place::<A>(elt.item.as_mut_ptr());
90+
}
91+
FoldWhile::Continue(())
92+
} else {
93+
FoldWhile::Done(())
94+
}
95+
});
96+
}
97+
}
98+
99+
impl<'a: 'c, 'b: 'c, 'c, A, D: Dimension> NdProducer for &'c mut PartialArray<'a, 'b, A, D> {
100+
// This just wraps ArrayViewMut as NdProducer and maps the item
101+
type Item = ProxyElem<'a, 'b, A>;
102+
type Dim = D;
103+
type Ptr = *mut MaybeUninit<A>;
104+
type Stride = isize;
105+
106+
private_impl! {}
107+
fn raw_dim(&self) -> Self::Dim {
108+
self.data.raw_dim()
109+
}
110+
111+
fn equal_dim(&self, dim: &Self::Dim) -> bool {
112+
self.data.equal_dim(dim)
113+
}
114+
115+
fn as_ptr(&self) -> Self::Ptr {
116+
NdProducer::as_ptr(&self.data)
117+
}
118+
119+
fn layout(&self) -> Layout {
120+
self.data.layout()
121+
}
122+
123+
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
124+
ProxyElem { filled: self.filled, item: &mut *ptr }
125+
}
126+
127+
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
128+
self.data.uget_ptr(i)
129+
}
130+
131+
fn stride_of(&self, axis: Axis) -> Self::Stride {
132+
self.data.stride_of(axis)
133+
}
134+
135+
#[inline(always)]
136+
fn contiguous_stride(&self) -> Self::Stride {
137+
self.data.contiguous_stride()
138+
}
139+
140+
fn split_at(self, _axis: Axis, _index: usize) -> (Self, Self) {
141+
unimplemented!();
142+
}
143+
}
144+

tests/azip.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,78 @@ fn test_zip_assign_into_cell() {
107107
assert_abs_diff_eq!(a2, &b + &c, epsilon = 1e-6);
108108
}
109109

110+
#[test]
111+
fn test_zip_collect_drop() {
112+
use std::cell::RefCell;
113+
use std::panic;
114+
115+
struct Recorddrop<'a>((usize, usize), &'a RefCell<Vec<(usize, usize)>>);
116+
117+
impl<'a> Drop for Recorddrop<'a> {
118+
fn drop(&mut self) {
119+
self.1.borrow_mut().push(self.0);
120+
}
121+
}
122+
123+
#[derive(Copy, Clone)]
124+
enum Config {
125+
CC,
126+
CF,
127+
FF,
128+
}
129+
130+
impl Config {
131+
fn a_is_f(self) -> bool {
132+
match self {
133+
Config::CC | Config::CF => false,
134+
_ => true,
135+
}
136+
}
137+
fn b_is_f(self) -> bool {
138+
match self {
139+
Config::CC => false,
140+
_ => true,
141+
}
142+
}
143+
}
144+
145+
let test_collect_panic = |config: Config, will_panic: bool, slice: bool| {
146+
let mut inserts = RefCell::new(Vec::new());
147+
let mut drops = RefCell::new(Vec::new());
148+
149+
let mut a = Array::from_shape_fn((5, 10).set_f(config.a_is_f()), |idx| idx);
150+
let mut b = Array::from_shape_fn((5, 10).set_f(config.b_is_f()), |_| 0);
151+
if slice {
152+
a = a.slice_move(s![.., ..-1]);
153+
b = b.slice_move(s![.., ..-1]);
154+
}
155+
156+
let _result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
157+
Zip::from(&a).and(&b).apply_collect(|&elt, _| {
158+
if elt.0 > 3 && will_panic {
159+
panic!();
160+
}
161+
inserts.borrow_mut().push(elt);
162+
Recorddrop(elt, &drops)
163+
});
164+
}));
165+
166+
println!("{:?}", inserts.get_mut());
167+
println!("{:?}", drops.get_mut());
168+
169+
assert_eq!(inserts.get_mut().len(), drops.get_mut().len(), "Incorrect number of drops");
170+
assert_eq!(inserts.get_mut(), drops.get_mut(), "Incorrect order of drops");
171+
};
172+
173+
for &should_panic in &[true, false] {
174+
for &should_slice in &[false, true] {
175+
test_collect_panic(Config::CC, should_panic, should_slice);
176+
test_collect_panic(Config::CF, should_panic, should_slice);
177+
test_collect_panic(Config::FF, should_panic, should_slice);
178+
}
179+
}
180+
}
181+
110182

111183
#[test]
112184
fn test_azip_syntax_trailing_comma() {

0 commit comments

Comments
 (0)
0