diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 96cbd21021f08..fb8fb5dfc1e7d 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -39,6 +39,10 @@ Changelog for storing the inverse document frequency. :pr:`18843` by :user:`Paolo Montesel `. +- |Enhancement| :class:`feature_extraction.text.TfidfTransformer` now preserves + the data type of the input matrix if it is `np.float64` or `np.float32`. + :pr:`28136` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.impute` ..................... diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index cef6f340e83c8..ea6686ef45eaa 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -1666,23 +1666,21 @@ def fit(self, X, y=None): ) if not sp.issparse(X): X = sp.csr_matrix(X) - dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64 + dtype = X.dtype if X.dtype in (np.float64, np.float32) else np.float64 if self.use_idf: - n_samples, n_features = X.shape + n_samples, _ = X.shape df = _document_frequency(X) df = df.astype(dtype, copy=False) # perform idf smoothing if required - df += int(self.smooth_idf) + df += float(self.smooth_idf) n_samples += int(self.smooth_idf) # log+1 instead of log makes sure terms with zero idf don't get # suppressed entirely. + # `np.log` preserves the dtype of `df` and thus `dtype`. self.idf_ = np.log(n_samples / df) + 1.0 - # FIXME: for backward compatibility, we force idf_ to be np.float64 - # In the future, we should preserve the `dtype` of `X`. - self.idf_ = self.idf_.astype(np.float64, copy=False) return self @@ -1705,14 +1703,18 @@ def transform(self, X, copy=True): """ check_is_fitted(self) X = self._validate_data( - X, accept_sparse="csr", dtype=FLOAT_DTYPES, copy=copy, reset=False + X, + accept_sparse="csr", + dtype=[np.float64, np.float32], + copy=copy, + reset=False, ) if not sp.issparse(X): - X = sp.csr_matrix(X, dtype=np.float64) + X = sp.csr_matrix(X, dtype=X.dtype) if self.sublinear_tf: np.log(X.data, X.data) - X.data += 1 + X.data += 1.0 if hasattr(self, "idf_"): # the columns of X (CSR matrix) can be accessed with `X.indices `and @@ -1725,7 +1727,12 @@ def transform(self, X, copy=True): return X def _more_tags(self): - return {"X_types": ["2darray", "sparse"]} + return { + "X_types": ["2darray", "sparse"], + # FIXME: np.float16 could be preserved if _inplace_csr_row_normalize_l2 + # accepted it. + "preserves_dtype": [np.float64, np.float32], + } class TfidfVectorizer(CountVectorizer):