Merged in sergem/numpy/fixed_axis_in_partition (pull request #11) · prototype99/numpy@725dcd4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 725dcd4

Browse files
committed
Merged in sergem/numpy/fixed_axis_in_partition (pull request ganesh-k13#11)
Fixed_axis_in_partition
2 parents 24deb37 + 2d3b087 commit 725dcd4

File tree

2 files changed

+134
-22
lines changed

2 files changed

+134
-22
lines changed

numpy/core/_partition_use.py

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from __future__ import print_function
2+
13
from _partition import lib, ffi
4+
25
from _parition_build import list_suff, list_type
36
from numpy.core.multiarray import dtype
4-
from numpy import apply_along_axis
5-
from numpy import partition as numpy_partition
67

78
_type_to_suff = dict(zip(list_type, list_suff))
89
_dtype_to_cffi_type = {dtype('int32'): 'npy_int',
@@ -18,6 +19,70 @@ def _cffi_type(dtype_input):
1819
return _dtype_to_cffi_type.get(dtype_input)
1920

2021

22+
class IndexesOverAxis(object):
    """
    Iterable over index tuples that select every 1-D lane along one axis.

    Each yielded tuple has ``slice(None)`` at position ``axis`` and an integer
    at every other position, so ``arr[index]`` is one 1-D lane of ``arr``.
    Similar functionality is implemented in numpy.apply_along_axis.

    The object can be iterated any number of times; every iteration starts
    from the first lane. If any dimension other than ``axis`` has length 0
    there are no lanes and nothing is yielded.

    >>> indexes = IndexesOverAxis((2,3,3), 1)
    >>> list(indexes)
    [(0, slice(None, None, None), 0), (0, slice(None, None, None), 1), (0, slice(None, None, None), 2), (1, slice(None, None, None), 0), (1, slice(None, None, None), 1), (1, slice(None, None, None), 2)]
    >>> indexes = IndexesOverAxis((2,2,3), 2)
    >>> list(indexes)
    [(0, 0, slice(None, None, None)), (0, 1, slice(None, None, None)), (1, 0, slice(None, None, None)), (1, 1, slice(None, None, None))]
    """

    def __init__(self, shape, axis):
        """
        Parameters
        ----------
        shape   sequence of non-negative ints, the shape of the array
        axis    axis to keep as a slice; negative values count from the end

        Raises
        ------
        ValueError  if ``shape`` has no dimensions
        IndexError  if ``axis`` is outside ``-len(shape)..len(shape) - 1``
        """
        len_shape = len(shape)
        if len_shape <= 0:
            raise ValueError("Shape must have at least one dimension")

        requested_axis = axis  # kept so the error reports what the caller passed
        if axis < 0:
            axis += len_shape

        if not (0 <= axis < len_shape):
            raise IndexError("Axis must be in {}..{}. Current value {}".format(
                -len_shape, len_shape - 1, requested_axis))

        self.axis = axis
        self.limits = list(shape)
        # A limit of 0 makes the axis "digit" overflow on every increment, so
        # it always stays 0 and simply forwards the carry to the next digit.
        self.limits[axis] = 0
        self.current_index_slice = [0] * len_shape

    @staticmethod
    def _generate_next(array, limits):
        """
        Performs per digit (per element) increment with overflow processing
        Assuming len(array) == len(limits), limits[x] >= 0 for each x.

        Parameters
        ----------
        array   current state of array (modified in place)
        limits  limits for each element.

        Returns
        -------
        True if ``array`` now holds the next index, False if the first
        (most significant) digit overflowed, i.e. the sequence is exhausted.
        """
        i = len(array) - 1
        array[i] += 1  # increment the last "digit"
        while array[i] >= limits[i]:  # while overflow
            if i <= 0:  # overflow in the first "digit" -> exit
                return False
            array[i] = 0
            array[i - 1] += 1  # carry
            i -= 1  # move to next "digit"
        return True

    def _get_output(self, index=None):
        """
        Build the index tuple for ``index`` (defaults to
        ``self.current_index_slice``) with ``slice(None)`` placed at the axis.
        """
        if index is None:
            index = self.current_index_slice
        output = list(index)  # copy, the caller's state must stay untouched
        output[self.axis] = slice(None)
        return tuple(output)

    def __iter__(self):
        # An empty dimension other than the axis means there are no lanes.
        for position, limit in enumerate(self.limits):
            if position != self.axis and limit <= 0:
                return

        # Iteration state is local, so the object can be iterated repeatedly
        # (and by several consumers) without one pass corrupting another.
        current = [0] * len(self.limits)
        while True:
            yield self._get_output(current)
            if not self._generate_next(current, self.limits):
                return
84+
85+
2186
def _partition_for_1d(a, kth, kind='introselect', order=None):
2287
"""
2388
Performs in-place partition on 1D array.
@@ -35,9 +100,11 @@ def _partition_for_1d(a, kth, kind='introselect', order=None):
35100
-------
36101
37102
"""
38-
assert kind == 'introselect'
39-
assert order is None
40103
assert a.ndim == 1
104+
if kind != 'introselect':
105+
raise NotImplementedError("kind == '{}' is not implemented yet".format(kind))
106+
if order is not None:
107+
raise NotImplementedError("Only order == None is implemented")
41108

42109
str_dst_type = _cffi_type(a.dtype)
43110
if str_dst_type is None:
@@ -64,6 +131,13 @@ def get_pointer(np_arr):
64131
raise RuntimeError("Something goes wrong in partition")
65132

66133

134+
def _apply_inplace_along_axis(func1d, axis, arr, args=(), kwargs=None):
    """
    Apply an in-place 1-D function to every lane of ``arr`` along ``axis``.

    Parameters
    ----------
    func1d  callable invoked as ``func1d(lane, *args, **kwargs)``; it is
            expected to modify ``lane`` in place, its return value is ignored
    axis    axis along which the lanes are taken
    arr     array that is modified in place
    args    extra positional arguments for ``func1d``
    kwargs  extra keyword arguments for ``func1d`` (None means no extras)

    Returns
    -------
    None, ``arr`` is updated in place.
    """
    if kwargs is None:  # avoid a shared mutable default argument
        kwargs = {}

    for indexes in IndexesOverAxis(arr.shape, axis):
        # Work on a copy of the lane and write the result back afterwards.
        # NOTE(review): presumably the copy is needed because func1d wants a
        # contiguous buffer - confirm against _partition_for_1d.
        extracted_axis = arr[indexes].copy()
        func1d(extracted_axis, *args, **kwargs)
        arr[indexes] = extracted_axis
139+
140+
67141
def partition(a, kth, axis=-1, kind='introselect', order=None):
68142
"""
69143
Performs partition inplace.
@@ -82,15 +156,15 @@ def partition(a, kth, axis=-1, kind='introselect', order=None):
82156
-------
83157
84158
"""
85-
if order is not None:
86-
raise NotImplementedError("Only order == None is implemented")
87-
if kind != 'introselect':
88-
raise NotImplementedError("kind == '{}' is not implemented yet".format(kind))
89159

90160
if a.size == 0:
91161
return None
92162

93-
if (axis == -1 or axis == a.ndim - 1) and a.ndim == 1:
94-
return _partition_for_1d(a, kth, kind, order)
95-
else:
96-
return apply_along_axis(numpy_partition, axis=axis, arr=a, kth=kth)
163+
try:
164+
if (axis == -1 or axis == a.ndim - 1) and a.ndim == 1:
165+
_partition_for_1d(a, kth, kind, order)
166+
else:
167+
_apply_inplace_along_axis(_partition_for_1d, axis=axis, arr=a, args=(),
168+
kwargs=dict(kth=kth, order=order, kind=kind))
169+
except NotImplementedError:
170+
a.sort(axis=axis, order=order)

numpy/core/tests/test_partition.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,69 @@ class TestPartition(TestCase):
1313
pivot = 2
1414
pivot_index = 2
1515

16+
array = np.array([[[5, 0, 3, 3],
17+
[7, 3, 5, 2],
18+
[4, 7, 6, 8]],
19+
20+
[[8, 1, 6, 7],
21+
[7, 8, 1, 5],
22+
[8, 4, 3, 0]]])
23+
1624
def test_uint16(self):
1725
arr = np.arange(10, -1, -1, dtype='int16')
1826

19-
if is_pypy():
20-
with self.assertRaises(NotImplementedError):
21-
np.partition(arr, self.pivot_index)
22-
else:
23-
partition = np.partition(arr, self.pivot_index)
24-
self._check_partition(partition, self.pivot_index, self.pivot)
27+
partition = np.partition(arr, self.pivot_index)
28+
self._check_partition(partition, self.pivot_index)
29+
self._check_content_along_axis(arr, partition, -1)
30+
2531

2632
def test_uint32(self):
2733
arr = np.arange(10, -1, -1, dtype='int32')
2834

2935
partition = np.partition(arr, self.pivot_index)
3036

31-
self._check_partition(partition, self.pivot_index, self.pivot)
37+
self._check_partition(partition, self.pivot_index)
38+
self._check_content_along_axis(arr, partition, -1)
3239

3340
def test_uint64(self):
3441
arr = np.arange(10, -1, -1, dtype='int64')
3542

3643
partition = np.partition(arr, self.pivot_index)
3744

38-
self._check_partition(partition, self.pivot_index, self.pivot)
45+
self._check_partition(partition, self.pivot_index)
46+
self._check_content_along_axis(arr, partition, -1)
3947

40-
def _check_partition(self, partition, pivot_index, pivot):
41-
self.assertTrue(np.all(partition[:pivot_index] < partition[pivot_index]))
48+
def _check_partition(self, partition, pivot_index):
49+
pivot = np.sort(partition)[pivot_index]
50+
self.assertTrue(np.all(partition[:pivot_index] <= partition[pivot_index]))
4251
self.assertTrue(np.all(partition[pivot_index:] >= partition[pivot_index]))
4352
self.assertTrue(partition[pivot_index] == pivot)
53+
return 0
54+
55+
def _check_multidimensional_partition(self, partition, axis, pivot_index):
56+
np.apply_along_axis(lambda x: self._check_partition(x, pivot_index), axis=axis, arr=partition)
57+
58+
def test_numpy_partition_doesnt_change_array(self):
59+
arr = self.array.copy()
60+
61+
np.partition(arr, 1)
62+
63+
self.assertTrue(np.array_equal(arr, self.array))
64+
65+
def test_multidimensional_axis_default(self):
66+
self._test_for_axis(axis=-1)
67+
68+
def test_multidimensional(self):
69+
self._test_for_axis(axis=0)
70+
self._test_for_axis(axis=1)
71+
self._test_for_axis(axis=2)
72+
73+
def _test_for_axis(self, axis):
74+
arr = self.array.copy()
75+
pivot_index = 1
76+
res = np.partition(arr, kth=pivot_index, axis=axis)
77+
self._check_content_along_axis(self.array, res, axis)
78+
self._check_multidimensional_partition(res, axis=axis, pivot_index=pivot_index)
79+
80+
def _check_content_along_axis(self, source, array, axis):
81+
self.assertTrue(np.array_equal(np.sort(source, axis=axis), np.sort(array, axis=axis)))

0 commit comments

Comments
 (0)
0