8000 FIX change stats.mode to Counter for better performance with object d… · scikit-learn/scikit-learn@0937b4a · GitHub
[go: up one dir, main page]

Skip to content

Commit 0937b4a

Browse files
authored
FIX change stats.mode to Counter for better performance with object dtype (#18987)
1 parent d953f74 commit 0937b4a

File tree

3 files changed

+46
-13
lines changed

3 files changed

+46
-13
lines changed

doc/whats_new/v0.24.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,10 @@ Changelog
369369
estimator's `random_state` attribute, allowing to use it with more external classes.
370370
:pr:`15636` by :user:`David Cortes <david-cortes>`.
371371

372+
- |Efficiency| :class:`impute.SimpleImputer` is now faster with `object` dtype array.
373+
when `strategy='most_frequent'` in :class:`~sklearn.impute.SimpleImputer`.
374+
:pr:`18987` by :user:`David Katz <DavidKatz-il>`.
375+
372376
:mod:`sklearn.inspection`
373377
.........................
374378

sklearn/impute/_base.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numbers
66
import warnings
7+
from collections import Counter
78

89
import numpy as np
910
import numpy.ma as ma
@@ -34,15 +35,20 @@ def _most_frequent(array, extra_value, n_repeat):
3435
of the array."""
3536
# Compute the most frequent value in array only
3637
if array.size > 0:
37-
with warnings.catch_warnings():
38-
# stats.mode raises a warning when input array contains objects due
39-
# to incapacity to detect NaNs. Irrelevant here since input array
40-
# has already been NaN-masked.
41-
warnings.simplefilter("ignore", RuntimeWarning)
38+
if array.dtype == object:
39+
# scipy.stats.mode is slow with object dtype array.
40+
# Python Counter is more efficient
41+
counter = Counter(array)
42+
most_frequent_count = counter.most_common(1)[0][1]
43+
# tie breaking similarly to scipy.stats.mode
44+
most_frequent_value = min(
45+
value for value, count in counter.items()
46+
if count == most_frequent_count
47+
)
48+
else:
4249
mode = stats.mode(array)
43-
44-
most_frequent_value = mode[0][0]
45-
most_frequent_count = mode[1][0]
50+
most_frequent_value = mode[0][0]
51+
most_frequent_count = mode[1][0]
4652
else:
4753
most_frequent_value = 0
4854
most_frequent_count = 0
@@ -55,11 +61,8 @@ def _most_frequent(array, extra_value, n_repeat):
5561
elif most_frequent_count > n_repeat:
5662
return most_frequent_value
5763
elif most_frequent_count == n_repeat:
58-
# Ties the breaks. Copy the behaviour of scipy.stats.mode
59-
if most_frequent_value < extra_value:
60-
return most_frequent_value
61-
else:
62-
return extra_value
64+
# tie breaking similarly to scipy.stats.mode
65+
return min(most_frequent_value, extra_value)
6366

6467

6568
class _BaseImputer(TransformerMixin, BaseEstimator):

sklearn/impute/tests/test_impute.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from sklearn import tree
2828
from sklearn.random_projection import _sparse_random_matrix
2929
from sklearn.exceptions import ConvergenceWarning
30+
from sklearn.impute._base import _most_frequent
3031

3132

3233
def _check_statistics(X, X_true,
@@ -1474,3 +1475,28 @@ def test_simple_imputation_inverse_transform_exceptions(missing_value):
14741475
with pytest.raises(ValueError,
14751476
match=f"Got 'add_indicator={imputer.add_indicator}'"):
14761477
imputer.inverse_transform(X_1_trans)
1478+
1479+
1480+
@pytest.mark.parametrize(
1481+
"expected,array,dtype,extra_value,n_repeat",
1482+
[
1483+
# array of object dtype
1484+
("extra_value", ['a', 'b', 'c'], object, "extra_value", 2),
1485+
(
1486+
"most_frequent_value",
1487+
['most_frequent_value', 'most_frequent_value', 'value'],
1488+
object, "extra_value", 1
1489+
),
1490+
("a", ['min_value', 'min_value' 'value'], object, "a", 2),
1491+
("min_value", ['min_value', 'min_value', 'value'], object, "z", 2),
1492+
# array of numeric dtype
1493+
(10, [1, 2, 3], int, 10, 2),
1494+
(1, [1, 1, 2], int, 10, 1),
1495+
(10, [20, 20, 1], int, 10, 2),
1496+
(1, [1, 1, 20], int, 10, 2),
1497+
]
1498+
)
1499+
def test_most_frequent(expected, array, dtype, extra_value, n_repeat):
1500+
assert expected == _most_frequent(
1501+
np.array(array, dtype=dtype), extra_value, n_repeat
1502+
)

0 commit comments

Comments
 (0)
0