ENH Allow for appropriate dtype use in `preprocessing.PolynomialFeatur… · scikit-learn/scikit-learn@c6628a0 · GitHub
[go: up one dir, main page]

Skip to content

Commit c6628a0

Browse files
Micky774, niuk-a, ogrisel, jjerphan, thomasjpfan
authored
ENH Allow for appropriate dtype use in preprocessing.PolynomialFeatures for sparse matrices (#23731)
Co-authored-by: Aleksandr Kokhaniukov <alexander.kohanyukov@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 31c8c75 commit c6628a0

File tree

5 files changed

+637
-107
lines changed

5 files changed

+637
-107
lines changed

doc/whats_new/v1.3.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,15 @@ Changelog
510510
during `transform` with no prior call to `fit` or `fit_transform`.
511511
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.
512512

513+
- |Enhancement| :class:`preprocessing.PolynomialFeatures` now calculates the
514+
number of expanded terms a priori when dealing with sparse `csr` matrices
515+
in order to optimize the choice of `dtype` for `indices` and `indptr`. It
516+
can now output `csr` matrices with `np.int32` `indices/indptr` components
517+
when there are few enough elements, and will automatically use `np.int64`
518+
for sufficiently large matrices.
519+
:pr:`20524` by :user:`niuk-a <niuk-a>` and
520+
:pr:`23731` by :user:`Meekail Zain <micky774>`.
521+
513522
- |API| A `FutureWarning` is now raised when instantiating a class which inherits from
514523
a deprecated base class (i.e. decorated by :class:`utils.deprecated`) and which
515524
overrides the `__init__` method.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def check_package_status(package, min_version):
293293
},
294294
],
295295
"preprocessing": [
296-
{"sources": ["_csr_polynomial_expansion.pyx"], "include_np": True},
296+
{"sources": ["_csr_polynomial_expansion.pyx"]},
297297
{
298298
"sources": ["_target_encoder_fast.pyx"],
299299
"include_np": True,
Lines changed: 186 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,178 @@
1-
# Author: Andrew nystrom <awnystrom@gmail.com>
1+
# Authors: Andrew Nystrom <awnystrom@gmail.com>
2+
# Meekail Zain <zainmeekail@gmail.com>
3+
from ..utils._typedefs cimport uint8_t, int64_t, intp_t
24

3-
from scipy.sparse import csr_matrix
4-
cimport numpy as cnp
5-
import numpy as np
5+
ctypedef uint8_t FLAG_t
6+
7+
# We use the following verbatim block to determine whether the current
8+
# platform's compiler supports 128-bit integer values intrinsically.
9+
# This should work for GCC and CLANG on 64-bit architectures, but doesn't for
10+
# MSVC on any architecture. We prefer to use 128-bit integers when possible
11+
# because the intermediate calculations have a non-trivial risk of overflow. It
12+
# is, however, very unlikely to come up on an average use case, hence 64-bit
13+
# integers (i.e. `long long`) are "good enough" for most common cases. There is
14+
# not much we can do to efficiently mitigate the overflow risk on the Windows
15+
# platform at this time. Consider this a "best effort" design decision that
16+
# could be revisited later in case someone comes up with a safer option that
17+
# does not hurt the performance of the common cases.
18+
# See `test_sizeof_LARGEST_INT_t()` for more information on exact type expectations.
19+
cdef extern from *:
20+
"""
21+
#ifdef __SIZEOF_INT128__
22+
typedef __int128 LARGEST_INT_t;
23+
#elif (__clang__ || __EMSCRIPTEN__) && !__i386__
24+
typedef _BitInt(128) LARGEST_INT_t;
25+
#else
26+
typedef long long LARGEST_INT_t;
27+
#endif
28+
"""
29+
ctypedef long long LARGEST_INT_t
30+
31+
32+
# Determine the size of `LARGEST_INT_t` at runtime.
33+
# Used in `test_sizeof_LARGEST_INT_t`.
34+
def _get_sizeof_LARGEST_INT_t():
35+
return sizeof(LARGEST_INT_t)
636

7-
cnp.import_array()
837

9-
# TODO: use `cnp.{int,float}{32,64}` when cython#5230 is resolved:
38+
# TODO: use `{int,float}{32,64}_t` when cython#5230 is resolved:
1039
# https://github.com/cython/cython/issues/5230
11-
ctypedef fused DATA_T:
40+
ctypedef fused DATA_t:
1241
float
1342
double
1443
int
15-
long
44+
long long
45+
# INDEX_{A,B}_t are defined to generate a proper Cartesian product
46+
# of types through Cython fused-type expansion.
47+
ctypedef fused INDEX_A_t:
48+
signed int
49+
signed long long
50+
ctypedef fused INDEX_B_t:
51+
signed int
52+
signed long long
1653

17-
18-
cdef inline cnp.int32_t _deg2_column(
19-
cnp.int32_t d,
20-
cnp.int32_t i,
21-
cnp.int32_t j,
22-
cnp.int32_t interaction_only,
23-
) noexcept nogil:
54+
cdef inline int64_t _deg2_column(
55+
LARGEST_INT_t n_features,
56+
LARGEST_INT_t i,
57+
LARGEST_INT_t j,
58+
FLAG_t interaction_only
59+
) nogil:
2460
"""Compute the index of the column for a degree 2 expansion
2561
26-
d is the dimensionality of the input data, i and j are the indices
62+
n_features is the dimensionality of the input data, i and j are the indices
2763
for the columns involved in the expansion.
2864
"""
2965
if interaction_only:
30-
return d * i - (i**2 + 3 * i) / 2 - 1 + j
66+
return n_features * i - i * (i + 3) / 2 - 1 + j
3167
else:
32-
return d * i - (i**2 + i) / 2 + j
68+
return n_features * i - i* (i + 1) / 2 + j
3369

3470

35-
cdef inline cnp.int32_t _deg3_column(
36-
cnp.int32_t d,
37-
cnp.int32_t i,
38-
cnp.int32_t j,
39-
cnp.int32_t k,
40-
cnp.int32_t interaction_only
41-
) noexcept nogil:
71+
cdef inline int64_t _deg3_column(
72+
LARGEST_INT_t n_features,
73+
LARGEST_INT_t i,
74+
LARGEST_INT_t j,
75+
LARGEST_INT_t k,
76+
FLAG_t interaction_only
77+
) nogil:
4278
"""Compute the index of the column for a degree 3 expansion
4379
44-
d is the dimensionality of the input data, i, j and k are the indices
80+
n_features is the dimensionality of the input data, i, j and k are the indices
4581
for the columns involved in the expansion.
4682
"""
4783
if interaction_only:
48-
return ((3 * d**2 * i - 3 * d * i**2 + i**3
49-
+ 11 * i - 3 * j**2 - 9 * j) / 6
50-
+ i**2 - 2 * d * i + d * j - d + k)
84+
return (
85+
(
86+
(3 * n_features) * (n_features * i - i**2)
87+
+ i * (i**2 + 11) - (3 * j) * (j + 3)
88+
) / 6 + i**2 + n_features * (j - 1 - 2 * i) + k
89+
)
90+
else:
91+
return (
92+
(
93+
(3 * n_features) * (n_features * i - i**2)
94+
+ i ** 3 - i - (3 * j) * (j + 1)
95+
) / 6 + n_features * j + k
96+
)
97+
98+
99+
def py_calc_expanded_nnz_deg2(n, interaction_only):
100+
return n * (n + 1) // 2 - interaction_only * n
101+
102+
103+
def py_calc_expanded_nnz_deg3(n, interaction_only):
104+
return n * (n**2 + 3 * n + 2) // 6 - interaction_only * n**2
105+
106+
107+
cpdef int64_t _calc_expanded_nnz(
108+
LARGEST_INT_t n,
109+
FLAG_t interaction_only,
110+
LARGEST_INT_t degree
111+
):
112+
"""
113+
Calculates the number of non-zero interaction terms generated by the
114+
non-zero elements of a single row.
115+
"""
116+
# This is the maximum value before the intermediate computation
117+
# d**2 + d overflows
118+
# Solution to d**2 + d = maxint64
119+
# SymPy: solve(x**2 + x - int64_max, x)
120+
cdef int64_t MAX_SAFE_INDEX_CALC_DEG2 = 3037000499
121+
122+
# This is the maximum value before the intermediate computation
123+
# d**3 + 3 * d**2 + 2*d overflows
124+
# Solution to d**3 + 3 * d**2 + 2*d = maxint64
125+
# SymPy: solve(x * (x**2 + 3 * x + 2) - int64_max, x)
126+
cdef int64_t MAX_SAFE_INDEX_CALC_DEG3 = 2097151
127+
128+
if degree == 2:
129+
# Only need to check when not using 128-bit integers
130+
if sizeof(LARGEST_INT_t) < 16 and n <= MAX_SAFE_INDEX_CALC_DEG2:
131+
return n * (n + 1) / 2 - interaction_only * n
132+
return <int64_t> py_calc_expanded_nnz_deg2(n, interaction_only)
51133
else:
52-
return ((3 * d**2 * i - 3 * d * i**2 + i ** 3 - i
53-
- 3 * j**2 - 3 * j) / 6
54-
+ d * j + k)
55-
56-
57-
def _csr_polynomial_expansion(
58-
const DATA_T[:] data,
59-
const cnp.int32_t[:] indices,
60-
const cnp.int32_t[:] indptr,
61-
cnp.int32_t d,
62-
cnp.int32_t interaction_only,
63-
cnp.int32_t degree
134+
# Only need to check when not using 128-bit integers
135+
if sizeof(LARGEST_INT_t) < 16 and n <= MAX_SAFE_INDEX_CALC_DEG3:
136+
return n * (n**2 + 3 * n + 2) / 6 - interaction_only * n**2
137+
return <int64_t> py_calc_expanded_nnz_deg3(n, interaction_only)
138+
139+
cpdef int64_t _calc_total_nnz(
140+
INDEX_A_t[:] indptr,
141+
FLAG_t interaction_only,
142+
int64_t degree,
64143
):
65144
"""
66-
Perform a second-degree polynomial or interaction expansion on a scipy
145+
Calculates the number of non-zero interaction terms generated by the
146+
non-zero elements across all rows for a single degree.
147+
"""
148+
cdef int64_t total_nnz=0
149+
cdef intp_t row_idx
150+
for row_idx in range(len(indptr) - 1):
151+
total_nnz += _calc_expanded_nnz(
152+
indptr[row_idx + 1] - indptr[row_idx],
153+
interaction_only,
154+
degree
155+
)
156+
return total_nnz
157+
158+
159+
cpdef void _csr_polynomial_expansion(
160+
const DATA_t[:] data, # IN READ-ONLY
161+
const INDEX_A_t[:] indices, # IN READ-ONLY
162+
const INDEX_A_t[:] indptr, # IN READ-ONLY
163+
INDEX_A_t n_features,
164+
DATA_t[:] result_data, # OUT
165+
INDEX_B_t[:] result_indices, # OUT
166+
INDEX_B_t[:] result_indptr, # OUT
167+
FLAG_t interaction_only,
168+
FLAG_t degree
169+
) nogil:
170+
"""
171+
Perform a second or third degree polynomial or interaction expansion on a
67172
compressed sparse row (CSR) matrix. The method used only takes products of
68-
non-zero features. For a matrix with density d, this results in a speedup
69-
on the order of d^k where k is the degree of the expansion, assuming all
70-
rows are of similar density.
173+
non-zero features. For a matrix with density :math:`d`, this results in a
174+
speedup on the order of :math:`(1/d)^k` where :math:`k` is the degree of
175+
the expansion, assuming all rows are of similar density.
71176
72177
Parameters
73178
----------
@@ -80,9 +185,21 @@ def _csr_polynomial_expansion(
80185
indptr : memory view on nd-array
81186
The "indptr" attribute of the input CSR matrix.
82187
83-
d : int
188+
n_features : int
84189
The dimensionality of the input CSR matrix.
85190
191+
result_data : nd-array
192+
The output CSR matrix's "data" attribute.
193+
It is modified by this routine.
194+
195+
result_indices : nd-array
196+
The output CSR matrix's "indices" attribute.
197+
It is modified by this routine.
198+
199+
result_indptr : nd-array
200+
The output CSR matrix's "indptr" attribute.
201+
It is modified by this routine.
202+
86203
interaction_only : int
87204
0 for a polynomial expansion, 1 for an interaction expansion.
88205
@@ -95,47 +212,11 @@ def _csr_polynomial_expansion(
95212
Matrices Using K-Simplex Numbers" by Andrew Nystrom and John Hughes.
96213
"""
97214

98-
assert degree in (2, 3)
99-
100-
if degree == 2:
101-
expanded_dimensionality = int((d**2 + d) / 2 - interaction_only*d)
102-
else:
103-
expanded_dimensionality = int((d**3 + 3*d**2 + 2*d) / 6
104-
- interaction_only*d**2)
105-
if expanded_dimensionality == 0:
106-
return None
107-
assert expanded_dimensionality > 0
108-
109-
cdef cnp.int32_t total_nnz = 0, row_i, nnz
110-
111-
# Count how many nonzero elements the expanded matrix will contain.
112-
for row_i in range(indptr.shape[0]-1):
113-
# nnz is the number of nonzero elements in this row.
114-
nnz = indptr[row_i + 1] - indptr[row_i]
115-
if degree == 2:
116-
total_nnz += (nnz ** 2 + nnz) / 2 - interaction_only * nnz
117-
else:
118-
total_nnz += ((nnz ** 3 + 3 * nnz ** 2 + 2 * nnz) / 6
119-
- interaction_only * nnz ** 2)
120-
121215
# Make the arrays that will form the CSR matrix of the expansion.
122-
cdef:
123-
DATA_T[:] expanded_data = np.empty(
124-
shape=total_nnz, dtype=data.base.dtype
125-
)
126-
cnp.int32_t[:] expanded_indices = np.empty(
127-
shape=total_nnz, dtype=np.int32
128-
)
129-
cnp.int32_t num_rows = indptr.shape[0] - 1
130-
cnp.int32_t[:] expanded_indptr = np.empty(
131-
shape=num_rows + 1, dtype=np.int32
132-
)
133-
134-
cnp.int32_t expanded_index = 0, row_starts, row_ends
135-
cnp.int32_t i, j, k, i_ptr, j_ptr, k_ptr, num_cols_in_row
136-
216+
cdef INDEX_A_t row_i, row_starts, row_ends, i, j, k, i_ptr, j_ptr, k_ptr
217+
cdef INDEX_B_t expanded_index=0, num_cols_in_row, col
137218
with nogil:
138-
expanded_indptr[0] = indptr[0]
219+
result_indptr[0] = indptr[0]
139220
for row_i in range(indptr.shape[0]-1):
140221
row_starts = indptr[row_i]
141222
row_ends = indptr[row_i + 1]
@@ -145,24 +226,32 @@ def _csr_polynomial_expansion(
145226
for j_ptr in range(i_ptr + interaction_only, row_ends):
146227
j = indices[j_ptr]
147228
if degree == 2:
148-
col = _deg2_column(d, i, j, interaction_only)
149-
expanded_indices[expanded_index] = col
150-
expanded_data[expanded_index] = (
151-
data[i_ptr] * data[j_ptr])
229+
col = <INDEX_B_t> _deg2_column(
230+
n_features,
231+
i, j,
232+
interaction_only
233+
)
234+
result_indices[expanded_index] = col
235+
result_data[expanded_index] = (
236+
data[i_ptr] * data[j_ptr]
237+
)
152238
expanded_index += 1
153239
num_cols_in_row += 1
154240
else:
155241
# degree == 3
156242
for k_ptr in range(j_ptr + interaction_only, row_ends):
157243
k = indices[k_ptr]
158-
col = _deg3_column(d, i, j, k, interaction_only)
159-
expanded_indices[expanded_index] = col
160-
expanded_data[expanded_index] = (
161-
data[i_ptr] * data[j_ptr] * data[k_ptr])
244+
col = <INDEX_B_t> _deg3_column(
245+
n_features,
246+
i, j, k,
247+
interaction_only
248+
)
249+
result_indices[expanded_index] = col
250+
result_data[expanded_index] = (
251+
data[i_ptr] * data[j_ptr] * data[k_ptr]
252+
)
162253
expanded_index += 1
163254
num_cols_in_row += 1
164255

165-
expanded_indptr[row_i+1] = expanded_indptr[row_i] + num_cols_in_row
166-
167-
return csr_matrix((expanded_data, expanded_indices, expanded_indptr),
168-
shape=(num_rows, expanded_dimensionality))
256+
result_indptr[row_i+1] = result_indptr[row_i] + num_cols_in_row
257+
return

0 commit comments

Comments
 (0)
0