8000 DOC: Add take_along_axis to the see also section in argmin, argmax et… · numpy/numpy@7b8513f · GitHub
[go: up one dir, main page]

Skip to content 8000

Commit 7b8513f

Browse files
mproszewskamattip
authored andcommitted
DOC: Add take_along_axis to the see also section in argmin, argmax etc. (#14799)
* Add take_along_axis to the see also section in argmin, argmax, argsort and argpartition and add examples
1 parent cadb066 commit 7b8513f

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

numpy/core/fromnumeric.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,9 @@ def argpartition(a, kth, axis=-1, kind='introselect', order=None):
796796
--------
797797
partition : Describes partition algorithms used.
798798
ndarray.partition : Inplace partition.
799-
argsort : Full indirect sort
799+
argsort : Full indirect sort.
800+
take_along_axis : Apply ``index_array`` from argpartition
801+
to an array as if by calling partition.
800802
801803
Notes
802804
-----
@@ -816,6 +818,14 @@ def argpartition(a, kth, axis=-1, kind='introselect', order=None):
816818
>>> np.array(x)[np.argpartition(x, 3)]
817819
array([2, 1, 3, 4])
818820
821+
Multi-dimensional array:
822+
823+
>>> x = np.array([[3, 4, 2], [1, 3, 1]])
824+
>>> index_array = np.argpartition(x, kth=1, axis=-1)
825+
>>> np.take_along_axis(x, index_array, axis=-1) # same as np.partition(x, kth=1)
826+
array([[2, 3, 4],
827+
[1, 1, 3]])
828+
819829
"""
820830
return _wrapfunc(a, 'argpartition', kth, axis=axis, kind=kind, order=order)
821831

@@ -1025,6 +1035,8 @@ def argsort(a, axis=-1, kind=None, order=None):
10251035
lexsort : Indirect stable sort with multiple keys.
10261036
ndarray.sort : Inplace sort.
10271037
argpartition : Indirect partial sort.
1038+
take_along_axis : Apply ``index_array`` from argsort
1039+
to an array as if by calling sort.
10281040
10291041
Notes
10301042
-----
@@ -1120,6 +1132,8 @@ def argmax(a, axis=None, out=None):
11201132
ndarray.argmax, argmin
11211133
amax : The maximum value along a given axis.
11221134
unravel_index : Convert a flat index into an index tuple.
1135+
take_along_axis : Apply ``np.expand_dims(index_array, axis)``
1136+
from argmax to an array as if by calling max.
11231137
11241138
Notes
11251139
-----
@@ -1154,6 +1168,16 @@ def argmax(a, axis=None, out=None):
11541168
>>> np.argmax(b) # Only the first occurrence is returned.
11551169
1
11561170
1171+
>>> x = np.array([[4,2,3], [1,0,3]])
1172+
>>> index_array = np.argmax(x, axis=-1)
1173+
>>> # Same as np.max(x, axis=-1, keepdims=True)
1174+
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
1175+
array([[4],
1176+
[3]])
1177+
>>> # Same as np.max(x, axis=-1)
1178+
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
1179+
array([4, 3])
1180+
11571181
"""
11581182
return _wrapfunc(a, 'argmax', axis=axis, out=out)
11591183

@@ -1189,6 +1213,8 @@ def argmin(a, axis=None, out=None):
11891213
ndarray.argmin, argmax
11901214
amin : The minimum value along a given axis.
11911215
unravel_index : Convert a flat index into an index tuple.
1216+
take_along_axis : Apply ``np.expand_dims(index_array, axis)``
1217+
from argmin to an array as if by calling min.
11921218
11931219
Notes
11941220
-----
@@ -1223,6 +1249,16 @@ def argmin(a, axis=None, out=None):
12231249
>>> np.argmin(b) # Only the first occurrence is returned.
12241250
0
12251251
1252+
>>> x = np.array([[4,2,3], [1,0,3]])
1253+
>>> index_array = np.argmin(x, axis=-1)
1254+
>>> # Same as np.min(x, axis=-1, keepdims=True)
1255+
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
1256+
array([[2],
1257+
[0]])
1258+
>>> # Same as np.max(x, axis=-1)
1259+
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
1260+
array([2, 0])
1261+
12261262
"""
12271263
return _wrapfunc(a, 'argmin', axis=axis, out=out)
12281264

0 commit comments

Comments
 (0)
0