10
10
11
11
from ..base import BaseEstimator , TransformerMixin
12
12
from ..utils import check_array
13
- from ..utils import as_float_array
14
13
from ..utils .fixes import astype
15
14
from ..utils .sparsefuncs import _get_median
16
15
from ..utils .validation import check_is_fitted
16
+ from ..utils .validation import FLOAT_DTYPES
17
17
18
18
from ..externals import six
19
19
@@ -310,15 +310,12 @@ def transform(self, X):
310
310
if self .axis == 0 :
311
311
check_is_fitted (self , 'statistics_' )
312
312
313
- # Copy just once
314
- X = as_float_array (X , copy = self .copy , force_all_finite = False )
315
-
316
313
# Since two different arrays can be provided in fit(X) and
317
314
# transform(X), the imputation data need to be recomputed
318
315
# when the imputation is done per sample
319
316
if self .axis == 1 :
320
- X = check_array (X , accept_sparse = 'csr' , force_all_finite = False ,
321
- copy = False )
317
+ X = check_array (X , accept_sparse = 'csr' , dtype = FLOAT_DTYPES ,
318
+ force_all_finite = False , copy = self . copy )
322
319
323
320
if sparse .issparse (X ):
324
321
statistics = self ._sparse_fit (X ,
@@ -332,8 +329,8 @@ def transform(self, X):
332
329
self .missing_values ,
333
330
self .axis )
334
331
else :
335
- X = check_array (X , accept_sparse = 'csc' , force_all_finite = False ,
336
- copy = False )
332
+ X = check_array (X , accept_sparse = 'csc' , dtype = FLOAT_DTYPES ,
333
+ force_all_finite = False , copy = self . copy )
337
334
statistics = self .statistics_
338
335
339
336
# Delete the invalid rows/columns
0 commit comments