FIX KNNImputer missing indicator column addition when add_indicator=True (#26600) · scikit-learn/scikit-learn@c3b6609 · GitHub
[go: up one dir, main page]

Skip to content

Commit c3b6609

Browse files
Shreesha3112 authored and jeremiedbb committed
FIX KNNImputer missing indicator column addition when add_indicator=True (#26600)
Co-authored-by: shreesha3112 <shreesha3112.com>
1 parent 0d4cda5 commit c3b6609

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

doc/whats_new/v1.3.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ Version 1.3.1
1212
Changelog
1313
---------
1414

15+
:mod:`sklearn.impute`
16+
.....................
17+
18+
- |Fix| :class:`impute.KNNImputer` now correctly adds a missing indicator column in
19+
``transform`` when ``add_indicator`` is set to ``True`` and missing values are observed
20+
during ``fit``. :pr:`26600` by :user:`Shreesha Kumar Bhat <Shreesha3112>`.
21+
1522
:mod:`sklearn.neighbors`
1623
........................
1724

sklearn/impute/_knn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,12 @@ def transform(self, X):
282282
Xc[:, ~valid_mask] = 0
283283
else:
284284
Xc = X[:, valid_mask]
285-
return Xc
285+
286+
# Even if there are no missing values in X, we still concatenate Xc
287+
# with the missing value indicator matrix, X_indicator.
288+
# This is to ensure that the output maintains consistency in terms
289+
# of columns, regardless of whether missing values exist in X or not.
290+
return super()._concatenate_indicator(Xc, X_indicator)
286291

287292
row_missing_idx = np.flatnonzero(mask.any(axis=1))
288293

sklearn/impute/tests/test_common.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,39 @@ def test_keep_empty_features(imputer, keep_empty_features):
181181
assert X_imputed.shape == X.shape
182182
else:
183183
assert X_imputed.shape == (X.shape[0], X.shape[1] - 1)
184+
185+
186+
@pytest.mark.parametrize("imputer", imputers(), ids=lambda x: x.__class__.__name__)
187+
@pytest.mark.parametrize("missing_value_test", [np.nan, 1])
188+
def test_imputation_adds_missing_indicator_if_add_indicator_is_true(
189+
imputer, missing_value_test
190+
):
191+
"""Check that missing indicator always exists when add_indicator=True.
192+
193+
Non-regression test for gh-26590.
194+
"""
195+
X_train = np.array([[0, np.NaN], [1, 2]])
196+
197+
# Test data where missing_value_test variable can be set to np.NaN or 1.
198+
X_test = np.array([[0, missing_value_test], [1, 2]])
199+
200+
imputer.set_params(add_indicator=True)
201+
imputer.fit(X_train)
202+
203+
X_test_imputed_with_indicator = imputer.transform(X_test)
204+
assert X_test_imputed_with_indicator.shape == (2, 3)
205+
206+
imputer.set_params(add_indicator=False)
207+
imputer.fit(X_train)
208+
X_test_imputed_without_indicator = imputer.transform(X_test)
209+
assert X_test_imputed_without_indicator.shape == (2, 2)
210+
211+
assert_allclose(
212+
X_test_imputed_with_indicator[:, :-1], X_test_imputed_without_indicator
213+
)
214+
if np.isnan(missing_value_test):
215+
expected_missing_indicator = [1, 0]
216+
else:
217+
expected_missing_indicator = [0, 0]
218+
219+
assert_allclose(X_test_imputed_with_indicator[:, -1], expected_missing_indicator)

0 commit comments

Comments
 (0)
0