ENH TfidfTransformer preserves np.float32 dtype by glemaitre · Pull Request #28136 · scikit-learn/scikit-learn · GitHub

ENH TfidfTransformer preserves np.float32 dtype #28136


Merged
12 commits merged on Jan 18, 2024
4 changes: 4 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -39,6 +39,10 @@ Changelog
for storing the inverse document frequency.
:pr:`18843` by :user:`Paolo Montesel <thebabush>`.

+ - |Enhancement| :class:`feature_extraction.text.TfidfTransformer` now preserves
+   the data type of the input matrix if it is `np.float64` or `np.float32`.
+   :pr:`28136` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.impute`
.....................

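As a minimal sketch of the user-visible behavior described by the changelog entry above (the count values are made up for illustration):

import numpy as np
import scipy.sparse as sp
from sklearn.feature_extraction.text import TfidfTransformer

# A small float32 term-count matrix (3 documents, 3 terms)
counts = sp.csr_matrix(
    np.array([[3, 0, 1], [2, 0, 0], [3, 0, 2]], dtype=np.float32)
)
tfidf = TfidfTransformer().fit(counts)
print(tfidf.idf_.dtype)               # float32 (forced to float64 before this PR)
print(tfidf.transform(counts).dtype)  # float32 is preserved end to end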
27 changes: 17 additions & 10 deletions sklearn/feature_extraction/text.py
@@ -1666,23 +1666,21 @@ def fit(self, X, y=None):
        )
        if not sp.issparse(X):
            X = sp.csr_matrix(X)
-       dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
+       dtype = X.dtype if X.dtype in (np.float64, np.float32) else np.float64

        if self.use_idf:
-           n_samples, n_features = X.shape
+           n_samples, _ = X.shape
            df = _document_frequency(X)
            df = df.astype(dtype, copy=False)

            # perform idf smoothing if required
-           df += int(self.smooth_idf)
+           df += float(self.smooth_idf)
            n_samples += int(self.smooth_idf)

            # log+1 instead of log makes sure terms with zero idf don't get
            # suppressed entirely.
+           # `np.log` preserves the dtype of `df` and thus `dtype`.
            self.idf_ = np.log(n_samples / df) + 1.0
-           # FIXME: for backward compatibility, we force idf_ to be np.float64
-           # In the future, we should preserve the `dtype` of `X`.
-           self.idf_ = self.idf_.astype(np.float64, copy=False)

        return self
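For concreteness, a standalone sketch of the smoothed idf computation in `fit` above, showing why the result now stays in the input dtype (the `df` values are made up for illustration):

import numpy as np

# 3 documents; df[i] = number of documents containing term i
n_samples = 3
df = np.array([3, 1, 2], dtype=np.float32)

# smooth_idf=True adds 1 to both counts; a Python float scalar does not
# upcast a float32 array under NumPy's promotion rules
df += 1.0
n_samples += 1

idf = np.log(n_samples / df) + 1.0
print(idf.dtype)  # float32: np.log preserves the dtype of df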

@@ -1705,14 +1703,18 @@ def transform(self, X, copy=True):
"""
check_is_fitted(self)
X = self._validate_data(
X, accept_sparse="csr", dtype=FLOAT_DTYPES, copy=copy, reset=False
X,
accept_sparse="csr",
dtype=[np.float64, np.float32],
copy=copy,
reset=False,
)
if not sp.issparse(X):
X = sp.csr_matrix(X, dtype=np.float64)
X = sp.csr_matrix(X, dtype=X.dtype)

if self.sublinear_tf:
np.log(X.data, X.data)
X.data += 1
X.data += 1.0

if hasattr(self, "idf_"):
# the columns of X (CSR matrix) can be accessed with `X.indices `and
@@ -1725,7 +1727,12 @@ def transform(self, X, copy=True):
        return X

    def _more_tags(self):
-       return {"X_types": ["2darray", "sparse"]}
+       return {
+           "X_types": ["2darray", "sparse"],
+           # FIXME: np.float16 could be preserved if _inplace_csr_row_normalize_l2
+           # accepted it.
+           "preserves_dtype": [np.float64, np.float32],
+       }
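Two details implied above, neither spelled out in the diff itself: passing a list of dtypes to `_validate_data` (which forwards to `check_array`) keeps the input dtype if it is in the list and otherwise casts to the first entry, and the `preserves_dtype` tag is what scikit-learn's common estimator checks use to verify that round-trip. A small sketch of the first point with the public `check_array` helper:

import numpy as np
from sklearn.utils import check_array

X32 = np.ones((2, 3), dtype=np.float32)
Xint = np.ones((2, 3), dtype=np.int64)

# float32 is in the accepted list, so it is kept as-is...
assert check_array(X32, dtype=[np.float64, np.float32]).dtype == np.float32
# ...while int64 is not accepted and is cast to the first entry, float64
assert check_array(Xint, dtype=[np.float64, np.float32]).dtype == np.float64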


class TfidfVectorizer(CountVectorizer):