@@ -796,7 +796,9 @@ def argpartition(a, kth, axis=-1, kind='introselect', order=None):
796
796
--------
797
797
partition : Describes partition algorithms used.
798
798
ndarray.partition : Inplace partition.
799
- argsort : Full indirect sort
799
+ argsort : Full indirect sort.
800
+ take_along_axis : Apply ``index_array`` from argpartition
801
+ to an array as if by calling partition.
800
802
801
803
Notes
802
804
-----
@@ -816,6 +818,14 @@ def argpartition(a, kth, axis=-1, kind='introselect', order=None):
816
818
>>> np.array(x)[np.argpartition(x, 3)]
817
819
array([2, 1, 3, 4])
818
820
821
+ Multi-dimensional array:
822
+
823
+ >>> x = np.array([[3, 4, 2], [1, 3, 1]])
824
+ >>> index_array = np.argpartition(x, kth=1, axis=-1)
825
+ >>> np.take_along_axis(x, index_array, axis=-1) # same as np.partition(x, kth=1)
826
+ array([[2, 3, 4],
827
+ [1, 1, 3]])
828
+
819
829
"""
820
830
return _wrapfunc (a , 'argpartition' , kth , axis = axis , kind = kind , order = order )
821
831
@@ -1025,6 +1035,8 @@ def argsort(a, axis=-1, kind=None, order=None):
1025
1035
lexsort : Indirect stable sort with multiple keys.
1026
1036
ndarray.sort : Inplace sort.
1027
1037
argpartition : Indirect partial sort.
1038
+ take_along_axis : Apply ``index_array`` from argsort
1039
+ to an array as if by calling sort.
1028
1040
1029
1041
Notes
1030
1042
-----
@@ -1120,6 +1132,8 @@ def argmax(a, axis=None, out=None):
1120
1132
ndarray.argmax, argmin
1121
1133
amax : The maximum value along a given axis.
1122
1134
unravel_index : Convert a flat index into an index tuple.
1135
+ take_along_axis : Apply ``np.expand_dims(index_array, axis)``
1136
+ from argmax to an array as if by calling max.
1123
1137
1124
1138
Notes
1125
1139
-----
@@ -1154,6 +1168,16 @@ def argmax(a, axis=None, out=None):
1154
1168
>>> np.argmax(b) # Only the first occurrence is returned.
1155
1169
1
1156
1170
1171
+ >>> x = np.array([[4,2,3], [1,0,3]])
1172
+ >>> index_array = np.argmax(x, axis=-1)
1173
+ >>> # Same as np.max(x, axis=-1, keepdims=True)
1174
+ >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
1175
+ array([[4],
1176
+ [3]])
1177
+ >>> # Same as np.max(x, axis=-1)
1178
+ >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
1179
+ array([4, 3])
1180
+
1157
1181
"""
1158
1182
return _wrapfunc (a , 'argmax' , axis = axis , out = out )
1159
1183
@@ -1189,6 +1213,8 @@ def argmin(a, axis=None, out=None):
1189
1213
ndarray.argmin, argmax
1190
1214
amin : The minimum value along a given axis.
1191
1215
unravel_index : Convert a flat index into an index tuple.
1216
+ take_along_axis : Apply ``np.expand_dims(index_array, axis)``
1217
+ from argmin to an array as if by calling min.
1192
1218
1193
1219
Notes
1194
1220
-----
@@ -1223,6 +1249,16 @@ def argmin(a, axis=None, out=None):
1223
1249
>>> np.argmin(b) # Only the first occurrence is returned.
1224
1250
0
1225
1251
1252
+ >>> x = np.array([[4,2,3], [1,0,3]])
1253
+ >>> index_array = np.argmin(x, axis=-1)
1254
+ >>> # Same as np.min(x, axis=-1, keepdims=True)
1255
+ >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
1256
+ array([[2],
1257
+ [0]])
1258
+ >>> # Same as np.max(x, axis=-1)
1259
+ >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
1260
+ array([2, 0])
1261
+
1226
1262
"""
1227
1263
return _wrapfunc (a , 'argmin' , axis = axis , out = out )
1228
1264
0 commit comments