diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index d9896468959a1..56d4281eb0945 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -168,7 +168,7 @@ class MaskedArray(_MaskedArray): def _take_along_axis(arr, indices, axis): """Implements a simplified version of np.take_along_axis if numpy version < 1.15""" - if np_version > parse_version('1.14'): + if np_version >= parse_version('1.15'): return np.take_along_axis(arr=arr, indices=indices, axis=axis) else: if axis is None: