1
+ from __future__ import print_function
2
+
1
3
from _partition import lib , ffi
4
+
2
5
from _parition_build import list_suff , list_type
3
6
from numpy .core .multiarray import dtype
4
- from numpy import apply_along_axis
5
- from numpy import partition as numpy_partition
6
7
7
8
_type_to_suff = dict (zip (list_type , list_suff ))
8
9
_dtype_to_cffi_type = {dtype ('int32' ): 'npy_int' ,
@@ -18,6 +19,70 @@ def _cffi_type(dtype_input):
18
19
return _dtype_to_cffi_type .get (dtype_input )
19
20
20
21
22
+ class IndexesOverAxis (object ):
23
+ """
24
+ Class for iterating over an array along one axis. Similar functionality is implemented in numpy.apply_along_axis.
25
+
26
+ >>> indexes = IndexesOverAxis((2,3,3), 1)
27
+ >>> list(indexes)
28
+ [(0, slice(None, None, None), 0), (0, slice(None, None, None), 1), (0, slice(None, None, None), 2), (1, slice(None, None, None), 0), (1, slice(None, None, None), 1), (1, slice(None, None, None), 2)]
29
+ >>> indexes = IndexesOverAxis((2,2,3), 2)
30
+ >>> list(indexes)
31
+ [(0, 0, slice(None, None, None)), (0, 1, slice(None, None, None)), (1, 0, slice(None, None, None)), (1, 1, slice(None, None, None))]
32
+ """
33
+
34
+ def __init__ (self , shape , axis ):
35
+ len_shape = len (shape )
36
+ if len_shape <= 0 :
37
+ raise ValueError ("Shape must have at least one dimension" )
38
+
39
+ if axis < 0 :
40
+ axis += len_shape
41
+
42
+ if not (0 <= axis < len_shape ):
43
+ raise IndexError ("Axis must be in 0..{}. Current value {}" .format (len_shape , axis ))
44
+
45
+ self .axis = axis
46
+ self .limits = list (shape )
47
+ self .limits [axis ] = 0
48
+ self .current_index_slice = [0 ] * len_shape
49
+
50
+ @staticmethod
51
+ def _generate_next (array , limits ):
52
+ """
53
+ Performs per digit (per element) increment with overflow processing
54
+ Assuming len(array) == len(limits), limits[x] >= 0 for each x.
55
+
56
+ Parameters
57
+ ----------
58
+ array current state of array
59
+ limits limits for each element.
60
+
61
+ Returns
62
+ -------
63
+ """
64
+ i = len (array ) - 1
65
+ array [i ] += 1 # increment the last "digit"
66
+ while array [i ] >= limits [i ]: # while overflow
67
+ if i <= 0 : # overflow in the last "digit" -> exit
68
+ return False
69
+ array [i ] = 0
70
+ array [i - 1 ] += 1
71
+ i -= 1 # move to next "digit"
72
+ return True
73
+
74
+ def _get_output (self ):
75
+ output = self .current_index_slice [:] # copy
76
+ output [self .axis ] = slice (None )
77
+ return tuple (output )
78
+
79
+ def __iter__ (self ):
80
+ while True :
81
+ yield self ._get_output ()
82
+ if not self ._generate_next (self .current_index_slice , self .limits ):
83
+ return
84
+
85
+
21
86
def _partition_for_1d (a , kth , kind = 'introselect' , order = None ):
22
87
"""
23
88
Performs in-place partition on 1D array.
@@ -35,9 +100,11 @@ def _partition_for_1d(a, kth, kind='introselect', order=None):
35
100
-------
36
101
37
102
"""
38
- assert kind == 'introselect'
39
- assert order is None
40
103
assert a .ndim == 1
104
+ if kind != 'introselect' :
105
+ raise NotImplementedError ("kind == '{}' is not implemented yet" .format (kind ))
106
+ if order is not None :
107
+ raise NotImplementedError ("Only order == None is implemented" )
41
108
42
109
str_dst_type = _cffi_type (a .dtype )
43
110
if str_dst_type is None :
@@ -64,6 +131,13 @@ def get_pointer(np_arr):
64
131
raise RuntimeError ("Something goes wrong in partition" )
65
132
66
133
134
+ def _apply_inplace_along_axis (func1d , axis , arr , args = (), kwargs = {}):
135
+ for indexes in IndexesOverAxis (arr .shape , axis ):
136
+ extracted_axis = arr [indexes ].copy ()
137
+ func1d (extracted_axis , * args , ** kwargs )
138
+ arr [indexes ] = extracted_axis
139
+
140
+
67
141
def partition (a , kth , axis = - 1 , kind = 'introselect' , order = None ):
68
142
"""
69
143
Performs partition inplace.
@@ -82,15 +156,15 @@ def partition(a, kth, axis=-1, kind='introselect', order=None):
82
156
-------
83
157
84
158
"""
85
- if order is not None :
86
- raise NotImplementedError ("Only order == None is implemented" )
87
- if kind != 'introselect' :
88
- raise NotImplementedError ("kind == '{}' is not implemented yet" .format (kind ))
89
159
90
160
if a .size == 0 :
91
161
return None
92
162
93
- if (axis == - 1 or axis == a .ndim - 1 ) and a .ndim == 1 :
94
- return _partition_for_1d (a , kth , kind , order )
95
- else :
96
- return apply_along_axis (numpy_partition , axis = axis , arr = a , kth = kth )
163
+ try :
164
+ if (axis == - 1 or axis == a .ndim - 1 ) and a .ndim == 1 :
165
+ _partition_for_1d (a , kth , kind , order )
166
+ else :
167
+ _apply_inplace_along_axis (_partition_for_1d , axis = axis , arr = a , args = (),
168
+ kwargs = dict (kth = kth , order = order , kind = kind ))
169
+ except NotImplementedError :
170
+ a .sort (axis = axis , order = order )
0 commit comments