ENH TfidfTransformer preserves np.float32 dtype (#28136) · scikit-learn/scikit-learn@fe5ba6f · GitHub

Commit fe5ba6f

ENH TfidfTransformer preserves np.float32 dtype (#28136)
1 parent d418e79 commit fe5ba6f

2 files changed: +21 -10 lines changed


doc/whats_new/v1.5.rst

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,10 @@ Changelog
   for storing the inverse document frequency.
   :pr:`18843` by :user:`Paolo Montesel <thebabush>`.
 
+- |Enhancement| :class:`feature_extraction.text.TfidfTransformer` now preserves
+  the data type of the input matrix if it is `np.float64` or `np.float32`.
+  :pr:`28136` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.impute`
 .....................
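
The effect of this entry, as a quick sketch (the toy count matrix below is made up for illustration and is not part of the commit):

    import numpy as np
    import scipy.sparse as sp
    from sklearn.feature_extraction.text import TfidfTransformer

    counts = sp.csr_matrix(
        np.array([[3, 0, 1], [2, 0, 0], [3, 0, 0], [4, 0, 0]], dtype=np.float32)
    )
    tfidf = TfidfTransformer().fit(counts)
    print(tfidf.idf_.dtype)               # float32 (was forced to float64 before)
    print(tfidf.transform(counts).dtype)  # float32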

sklearn/feature_extraction/text.py

Lines changed: 17 additions & 10 deletions
@@ -1666,23 +1666,21 @@ def fit(self, X, y=None):
         )
         if not sp.issparse(X):
             X = sp.csr_matrix(X)
-        dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
+        dtype = X.dtype if X.dtype in (np.float64, np.float32) else np.float64
 
         if self.use_idf:
-            n_samples, n_features = X.shape
+            n_samples, _ = X.shape
             df = _document_frequency(X)
             df = df.astype(dtype, copy=False)
 
             # perform idf smoothing if required
-            df += int(self.smooth_idf)
+            df += float(self.smooth_idf)
             n_samples += int(self.smooth_idf)
 
             # log+1 instead of log makes sure terms with zero idf don't get
             # suppressed entirely.
+            # `np.log` preserves the dtype of `df` and thus `dtype`.
             self.idf_ = np.log(n_samples / df) + 1.0
-            # FIXME: for backward compatibility, we force idf_ to be np.float64
-            # In the future, we should preserve the `dtype` of `X`.
-            self.idf_ = self.idf_.astype(np.float64, copy=False)
 
         return self
 
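
A note on why dropping the final `astype(np.float64, ...)` cast is enough here: NumPy arithmetic and `np.log` keep float32 inputs as float32, so `idf_` simply inherits the `dtype` chosen from `X`. A minimal sketch with made-up document frequencies (the values are an assumption for illustration):

    import numpy as np

    df = np.asarray([2.0, 1.0, 4.0], dtype=np.float32)  # document frequencies
    n_samples = 4 + 1   # smooth_idf adds one "virtual" document
    df += 1.0           # in-place smoothing keeps float32
    idf = np.log(n_samples / df) + 1.0
    print(idf.dtype)    # float32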
@@ -1705,14 +1703,18 @@ def transform(self, X, copy=True):
         """
         check_is_fitted(self)
         X = self._validate_data(
-            X, accept_sparse="csr", dtype=FLOAT_DTYPES, copy=copy, reset=False
+            X,
+            accept_sparse="csr",
+            dtype=[np.float64, np.float32],
+            copy=copy,
+            reset=False,
         )
         if not sp.issparse(X):
-            X = sp.csr_matrix(X, dtype=np.float64)
+            X = sp.csr_matrix(X, dtype=X.dtype)
 
         if self.sublinear_tf:
             np.log(X.data, X.data)
-            X.data += 1
+            X.data += 1.0
 
         if hasattr(self, "idf_"):
             # the columns of X (CSR matrix) can be accessed with `X.indices `and
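
The `dtype=[np.float64, np.float32]` argument relies on the usual scikit-learn validation semantics: when a list of dtypes is passed, an input whose dtype is in the list is kept as-is, and anything else is converted to the first entry. A small sketch through the public `check_array` helper (which the private `_validate_data` method wraps):

    import numpy as np
    from sklearn.utils import check_array

    X32 = np.ones((2, 3), dtype=np.float32)
    Xint = np.ones((2, 3), dtype=np.int64)
    print(check_array(X32, dtype=[np.float64, np.float32]).dtype)   # float32, kept
    print(check_array(Xint, dtype=[np.float64, np.float32]).dtype)  # float64, converted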
@@ -1725,7 +1727,12 @@ def transform(self, X, copy=True):
         return X
 
     def _more_tags(self):
-        return {"X_types": ["2darray", "sparse"]}
+        return {
+            "X_types": ["2darray", "sparse"],
+            # FIXME: np.float16 could be preserved if _inplace_csr_row_normalize_l2
+            # accepted it.
+            "preserves_dtype": [np.float64, np.float32],
+        }
 
 
 class TfidfVectorizer(CountVectorizer):
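
The new `preserves_dtype` entry tells the common estimator checks which input dtypes the transformer is expected to keep in its output (the default tag lists only `np.float64`). A quick way to inspect it, using the private `_get_tags` helper of this scikit-learn version (an internal API that may change):

    from sklearn.feature_extraction.text import TfidfTransformer

    print(TfidfTransformer()._get_tags()["preserves_dtype"])
    # [<class 'numpy.float64'>, <class 'numpy.float32'>]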
