@@ -1666,23 +1666,21 @@ def fit(self, X, y=None):
1666
1666
)
1667
1667
if not sp .issparse (X ):
1668
1668
X = sp .csr_matrix (X )
1669
- dtype = X .dtype if X .dtype in FLOAT_DTYPES else np .float64
1669
+ dtype = X .dtype if X .dtype in ( np . float64 , np . float32 ) else np .float64
1670
1670
1671
1671
if self .use_idf :
1672
- n_samples , n_features = X .shape
1672
+ n_samples , _ = X .shape
1673
1673
df = _document_frequency (X )
1674
1674
df = df .astype (dtype , copy = False )
1675
1675
1676
1676
# perform idf smoothing if required
1677
- df += int (self .smooth_idf )
1677
+ df += float (self .smooth_idf )
1678
1678
n_samples += int (self .smooth_idf )
1679
1679
1680
1680
# log+1 instead of log makes sure terms with zero idf don't get
1681
1681
# suppressed entirely.
1682
+ # `np.log` preserves the dtype of `df` and thus `dtype`.
1682
1683
self .idf_ = np .log (n_samples / df ) + 1.0
1683
- # FIXME: for backward compatibility, we force idf_ to be np.float64
1684
- # In the future, we should preserve the `dtype` of `X`.
1685
- self .idf_ = self .idf_ .astype (np .float64 , copy = False )
1686
1684
1687
1685
return self
1688
1686
@@ -1705,14 +1703,18 @@ def transform(self, X, copy=True):
1705
1703
"""
1706
1704
check_is_fitted (self )
1707
1705
X = self ._validate_data (
1708
- X , accept_sparse = "csr" , dtype = FLOAT_DTYPES , copy = copy , reset = False
1706
+ X ,
1707
+ accept_sparse = "csr" ,
1708
+ dtype = [np .float64 , np .float32 ],
1709
+ copy = copy ,
1710
+ reset = False ,
1709
1711
)
1710
1712
if not sp .issparse (X ):
1711
- X = sp .csr_matrix (X , dtype = np . float64 )
1713
+ X = sp .csr_matrix (X , dtype = X . dtype )
1712
1714
1713
1715
if self .sublinear_tf :
1714
1716
np .log (X .data , X .data )
1715
- X .data += 1
1717
+ X .data += 1.0
1716
1718
1717
1719
if hasattr (self , "idf_" ):
1718
1720
# the columns of X (CSR matrix) can be accessed with `X.indices` and
@@ -1725,7 +1727,12 @@ def transform(self, X, copy=True):
1725
1727
return X
1726
1728
1727
1729
def _more_tags (self ):
1728
- return {"X_types" : ["2darray" , "sparse" ]}
1730
+ return {
1731
+ "X_types" : ["2darray" , "sparse" ],
1732
+ # FIXME: np.float16 could be preserved if _inplace_csr_row_normalize_l2
1733
+ # accepted it.
1734
+ "preserves_dtype" : [np .float64 , np .float32 ],
1735
+ }
1729
1736
1730
1737
1731
1738
class TfidfVectorizer (CountVectorizer ):
0 commit comments