Merge pull request #2938 from jnothman/clean_impute · rmurcek/scikit-learn@14b435c · GitHub
[go: up one dir, main page]

Skip to content

Commit 14b435c

Browse files
committed
Merge pull request scikit-learn#2938 from jnothman/clean_impute
[MRG+1] some clean-up in Imputer, particularly in calculation of sparse median
2 parents 2fe838c + 79a8012 commit 14b435c

File tree

2 files changed

+67
-51
lines changed

2 files changed

+67
-51
lines changed

sklearn/preprocessing/imputation.py

Lines changed: 27 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -33,42 +33,32 @@ def _get_mask(X, value_to_mask):
3333
return X == value_to_mask
3434

3535

36-
def _get_median(negative_elements, n_zeros, positive_elements):
37-
"""Compute the median of the array formed by negative_elements,
38-
n_zeros zeros and positive_elements. This function is used
39-
to support sparse matrices."""
40-
negative_elements = np.sort(negative_elements, kind='heapsort')
41-
positive_elements = np.sort(positive_elements, kind='heapsort')
42-
43-
n_elems = len(negative_elements) + n_zeros + len(positive_elements)
36+
def _get_median(data, n_zeros):
37+
"""Compute the median of data with n_zeros additional zeros.
38+
39+
This function is used to support sparse matrices; it modifies data in-place
40+
"""
41+
n_elems = len(data) + n_zeros
4442
if not n_elems:
4543
return np.nan
44+
n_negative = np.count_nonzero(data < 0)
45+
middle, is_odd = divmod(n_elems, 2)
46+
data.sort()
4647

47-
median_position = (n_elems - 1) / 2.0
48+
if is_odd:
49+
return _get_elem_at_rank(middle, data, n_negative, n_zeros)
4850

49-
if round(median_position) == median_position:
50-
median = _get_elem_at_rank(negative_elements, n_zeros,
51-
positive_elements, median_position)
52-
else:
53-
a = _get_elem_at_rank(negative_elements, n_zeros,
54-
positive_elements, math.floor(median_position))
55-
b = _get_elem_at_rank(negative_elements, n_zeros,
56-
positive_elements, math.ceil(median_position))
57-
median = (a + b) / 2.0
58-
59-
return median
60-
61-
62-
def _get_elem_at_rank(negative_elements, n_zeros, positive_elements, k):
63-
"""Compute the kth largest element of the array formed by
64-
negative_elements, n_zeros zeros and positive_elements."""
65-
len_neg = len(negative_elements)
66-
if k < len_neg:
67-
return negative_elements[k]
68-
elif k >= len_neg + n_zeros:
69-
return positive_elements[k - len_neg - n_zeros]
70-
else:
51+
return (_get_elem_at_rank(middle - 1, data, n_negative, n_zeros) +
52+
_get_elem_at_rank(middle, data, n_negative, n_zeros)) / 2.
53+
54+
55+
def _get_elem_at_rank(rank, data, n_negative, n_zeros):
56+
"""Find the value in data augmented with n_zeros for the given rank"""
57+
if rank < n_negative:
58+
return data[rank]
59+
if rank - n_negative < n_zeros:
7160
return 0
61+
return data[rank - n_zeros]
7262

7363

7464
def _most_frequent(array, extra_value, n_repeat):
@@ -137,8 +127,8 @@ class Imputer(BaseEstimator, TransformerMixin):
137127
138128
Attributes
139129
----------
140-
`statistics_` : array of shape (n_features,) or (n_samples,)
141-
The statistics along the imputation axis.
130+
`statistics_` : array of shape (n_features,)
131+
The imputation fill value for each feature if axis == 0.
142132
143133
Notes
144134
-----
@@ -211,7 +201,7 @@ def _sparse_fit(self, X, strategy, missing_values, axis):
211201

212202
# Count the zeros
213203
if missing_values == 0:
214-
n_zeros_axis = np.zeros(X.shape[not axis])
204+
n_zeros_axis = np.zeros(X.shape[not axis], dtype=int)
215205
else:
216206
n_zeros_axis = X.shape[axis] - np.diff(X.indptr)
217207

@@ -257,19 +247,15 @@ def _sparse_fit(self, X, strategy, missing_values, axis):
257247
mask_valids = np.hsplit(np.logical_not(mask_missing_values),
258248
X.indptr[1:-1])
259249

260-
columns = [col[mask.astype(np.bool)]
250+
# astype necessary for bug in numpy.hsplit before v1.9
251+
columns = [col[mask.astype(bool, copy=False)]
261252
for col, mask in zip(columns_all, mask_valids)]
262253

263254
# Median
264255
if strategy == "median":
265256
median = np.empty(len(columns))
266257
for i, column in enumerate(columns):
267-
268-
negatives = column[column < 0]
269-
positives = column[column > 0]
270-
median[i] = _get_median(negatives,
271-
n_zeros_axis[i],
272-
positives)
258+
median[i] = _get_median(column, n_zeros_axis[i])
273259

274260
return median
275261

sklearn/preprocessing/tests/test_imputation.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ def _check_statistics(X, X_true,
4343
assert_raises(ValueError, imputer.transform, X.copy().transpose())
4444
else:
4545
X_trans = imputer.transform(X.copy().transpose())
46-
assert_array_equal(imputer.statistics_, statistics,
47-
err_msg.format(1, False))
4846
assert_array_equal(X_trans, X_true.transpose(),
4947
err_msg.format(1, False))
5048

@@ -72,8 +70,6 @@ def _check_statistics(X, X_true,
7270
if sparse.issparse(X_trans):
7371
X_trans = X_trans.toarray()
7472

75-
assert_array_equal(imputer.statistics_, statistics,
76-
err_msg.format(1, True))
7773
assert_array_equal(X_trans, X_true.transpose(),
7874
err_msg.format(1, True))
7975

@@ -109,16 +105,20 @@ def test_imputation_mean_median_only_zero():
109105
])
110106
statistics_mean = [np.nan, 3, np.nan, np.nan, 7]
111107

108+
# Behaviour of median with NaN is undefined, e.g. different results in
109+
# np.median and np.ma.median
110+
X_for_median = X[:, [0, 1, 2, 4]]
112111
X_imputed_median = np.array([
113-
[2, 5, 5],
114-
[1, np.nan, 3],
115-
[2, 5, 5],
116-
[6, 5, 13],
112+
[2, 5],
113+
[1, 3],
114+
[2, 5],
115+
[6, 13],
117116
])
118-
statistics_median = [np.nan, 2, np.nan, 5, 5]
117+
statistics_median = [np.nan, 2, np.nan, 5]
119118

120119
_check_statistics(X, X_imputed_mean, "mean", statistics_mean, 0)
121-
_check_statistics(X, X_imputed_median, "median", statistics_median, 0)
120+
_check_statistics(X_for_median, X_imputed_median, "median",
121+
statistics_median, 0)
122122

123123

124124
def test_imputation_mean_median():
@@ -191,6 +191,36 @@ def test_imputation_mean_median():
191191
true_statistics, test_missing_values)
192192

193193

194+
def test_imputation_median_special_cases():
195+
"""Test median imputation with sparse boundary cases
196+
"""
197+
X = np.array([
198+
[0, np.nan, np.nan], # odd: implicit zero
199+
[5, np.nan, np.nan], # odd: explicit nonzero
200+
[0, 0, np.nan], # even: average two zeros
201+
[-5, 0, np.nan], # even: avg zero and neg
202+
[0, 5, np.nan], # even: avg zero and pos
203+
[4, 5, np.nan], # even: avg nonzeros
204+
[-4, -5, np.nan], # even: avg negatives
205+
[-1, 2, np.nan], # even: crossing neg and pos
206+
]).transpose()
207+
208+
X_imputed_median = np.array([
209+
[0, 0, 0],
210+
[5, 5, 5],
211+
[0, 0, 0],
212+
[-5, 0, -2.5],
213+
[0, 5, 2.5],
214+
[4, 5, 4.5],
215+
[-4, -5, -4.5],
216+
[-1, 2, .5],
217+
]).transpose()
218+
statistics_median = [0, 5, 0, -2.5, 2.5, 4.5, -4.5, .5]
219+
220+
_check_statistics(X, X_imputed_median, "median",
221+
statistics_median, 'NaN')
222+
223+
194224
def test_imputation_most_frequent():
195225
"""Test imputation using the most-frequent strategy."""
196226
X = np.array([

0 commit comments

Comments
 (0)
0