FEAT Add custom imputation strategy to SimpleImputer (#28053) · punndcoder28/scikit-learn@e2b3785 · GitHub
[go: up one dir, main page]

Skip to content

Commit e2b3785

Browse files
mark-thm and jnothman authored
FEAT Add custom imputation strategy to SimpleImputer (scikit-learn#28053)
Co-authored-by: Joel Nothman <joeln@canva.com>
1 parent b5827cb commit e2b3785

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed

doc/whats_new/v1.5.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ Changelog
2525
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
2626
where 123456 is the *pull request* number, not the issue number.
2727
28+
:mod:`sklearn.impute`
29+
.....................
30+
- |Enhancement| :class:`impute.SimpleImputer` now supports custom strategies
31+
by passing a function in place of a strategy name.
32+
:pr:`28053` by :user:`Mark Elliot <mark-thm>`.
33+
2834
Code and Documentation Contributors
2935
-----------------------------------
3036

sklearn/impute/_base.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import warnings
77
from collections import Counter
88
from functools import partial
9+
from typing import Callable
910

1011
import numpy as np
1112
import numpy.ma as ma
@@ -163,7 +164,7 @@ class SimpleImputer(_BaseImputer):
163164
nullable integer dtypes with missing values, `missing_values`
164165
can be set to either `np.nan` or `pd.NA`.
165166
166-
strategy : str, default='mean'
167+
strategy : str or Callable, default='mean'
167168
The imputation strategy.
168169
169170
- If "mean", then replace missing values using the mean along
@@ -175,10 +176,16 @@ class SimpleImputer(_BaseImputer):
175176
If there is more than one such value, only the smallest is returned.
176177
- If "constant", then replace missing values with fill_value. Can be
177178
used with strings or numeric data.
179+
- If an instance of Callable, then replace missing values using the
180+
scalar statistic returned by running the callable over a dense 1d
181+
array containing non-missing values of each column.
178182
179183
.. versionadded:: 0.20
180184
strategy="constant" for fixed value imputation.
181185
186+
.. versionadded:: 1.5
187+
strategy=callable for custom value imputation.
188+
182189
fill_value : str or numerical value, default=None
183190
When strategy == "constant", `fill_value` is used to replace all
184191
occurrences of missing_values. For string or object data types,
@@ -270,7 +277,10 @@ class SimpleImputer(_BaseImputer):
270277

271278
_parameter_constraints: dict = {
272279
**_BaseImputer._parameter_constraints,
273-
"strategy": [StrOptions({"mean", "median", "most_frequent", "constant"})],
280+
"strategy": [
281+
StrOptions({"mean", "median", "most_frequent", "constant"}),
282+
callable,
283+
],
274284
"fill_value": "no_validation", # any object is valid
275285
"copy": ["boolean"],
276286
}
@@ -456,6 +466,9 @@ def _sparse_fit(self, X, strategy, missing_values, fill_value):
456466
elif strategy == "most_frequent":
457467
statistics[i] = _most_frequent(column, 0, n_zeros)
458468

469+
elif isinstance(strategy, Callable):
470+
statistics[i] = self.strategy(column)
471+
459472
super()._fit_indicator(missing_mask)
460473

461474
return statistics
@@ -518,6 +531,13 @@ def _dense_fit(self, X, strategy, missing_values, fill_value):
518531
# fill_value in each column
519532
return np.full(X.shape[1], fill_value, dtype=X.dtype)
520533

534+
# Custom
535+
elif isinstance(strategy, Callable):
536+
statistics = np.empty(masked_X.shape[1])
537+
for i in range(masked_X.shape[1]):
538+
statistics[i] = self.strategy(masked_X[:, i].compressed())
539+
return statistics
540+
521541
def transform(self, X):
522542
"""Impute all missing values in `X`.
523543

sklearn/impute/tests/test_impute.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,3 +1710,37 @@ def test_simple_imputer_keep_empty_features(strategy, array_type, keep_empty_fea
17101710
assert_array_equal(constant_feature, 0)
17111711
else:
17121712
assert X_imputed.shape == (X.shape[0], X.shape[1] - 1)
1713+
1714+
1715+
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
1716+
def test_imputation_custom(csc_container):
1717+
X = np.array(
1718+
[
1719+
[1.1, 1.1, 1.1],
1720+
[3.9, 1.2, np.nan],
1721+
[np.nan, 1.3, np.nan],
1722+
[0.1, 1.4, 1.4],
1723+
[4.9, 1.5, 1.5],
1724+
[np.nan, 1.6, 1.6],
1725+
]
1726+
)
1727+
1728+
X_true = np.array(
1729+
[
1730+
[1.1, 1.1, 1.1],
1731+
[3.9, 1.2, 1.1],
1732+
[0.1, 1.3, 1.1],
1733+
[0.1, 1.4, 1.4],
1734+
[4.9, 1.5, 1.5],
1735+
[0.1, 1.6, 1.6],
1736+
]
1737+
)
1738+
1739+
imputer = SimpleImputer(missing_values=np.nan, strategy=np.min)
1740+
X_trans = imputer.fit_transform(X)
1741+
assert_array_equal(X_trans, X_true)
1742+
1743+
# Sparse matrix
1744+
imputer = SimpleImputer(missing_values=np.nan, strategy=np.min)
1745+
X_trans = imputer.fit_transform(csc_container(X))
1746+
assert_array_equal(X_trans.toarray(), X_true)

0 commit comments

Comments
 (0)
0