6
6
import warnings
7
7
from collections import Counter
8
8
from functools import partial
9
+ from typing import Callable
9
10
10
11
import numpy as np
11
12
import numpy .ma as ma
@@ -163,7 +164,7 @@ class SimpleImputer(_BaseImputer):
163
164
nullable integer dtypes with missing values, `missing_values`
164
165
can be set to either `np.nan` or `pd.NA`.
165
166
166
- strategy : str, default='mean'
167
+ strategy : str or Callable , default='mean'
167
168
The imputation strategy.
168
169
169
170
- If "mean", then replace missing values using the mean along
@@ -175,10 +176,16 @@ class SimpleImputer(_BaseImputer):
175
176
If there is more than one such value, only the smallest is returned.
176
177
- If "constant", then replace missing values with fill_value. Can be
177
178
used with strings or numeric data.
179
+ - If an instance of Callable, then replace missing values using the
180
+ scalar statistic returned by running the callable over a dense 1d
181
+ array containing non-missing values of each column.
178
182
179
183
.. versionadded:: 0.20
180
184
strategy="constant" for fixed value imputation.
181
185
186
+ .. versionadded:: 1.5
187
+ strategy=callable for custom value imputation.
188
+
182
189
fill_value : str or numerical value, default=None
183
190
When strategy == "constant", `fill_value` is used to replace all
184
191
occurrences of missing_values. For string or object data types,
@@ -270,7 +277,10 @@ class SimpleImputer(_BaseImputer):
270
277
271
278
_parameter_constraints : dict = {
272
279
** _BaseImputer ._parameter_constraints ,
273
- "strategy" : [StrOptions ({"mean" , "median" , "most_frequent" , "constant" })],
280
+ "strategy" : [
281
+ StrOptions ({"mean" , "median" , "most_frequent" , "constant" }),
282
+ callable ,
283
+ ],
274
284
"fill_value" : "no_validation" , # any object is valid
275
285
"copy" : ["boolean" ],
276
286
}
@@ -456,6 +466,9 @@ def _sparse_fit(self, X, strategy, missing_values, fill_value):
456
466
elif strategy == "most_frequent" :
457
467
statistics [i ] = _most_frequent (column , 0 , n_zeros )
458
468
469
+ elif isinstance (strategy , Callable ):
470
+ statistics [i ] = self .strategy (column )
471
+
459
472
super ()._fit_indicator (missing_mask )
460
473
461
474
return statistics
@@ -518,6 +531,13 @@ def _dense_fit(self, X, strategy, missing_values, fill_value):
518
531
# fill_value in each column
519
532
return np .full (X .shape [1 ], fill_value , dtype = X .dtype )
520
533
534
+ # Custom
535
+ elif isinstance (strategy , Callable ):
536
+ statistics = np .empty (masked_X .shape [1 ])
537
+ for i in range (masked_X .shape [1 ]):
538
+ statistics [i ] = self .strategy (masked_X [:, i ].compressed ())
539
+ return statistics
540
+
521
541
def transform (self , X ):
522
542
"""Impute all missing values in `X`.
523
543
0 commit comments