WIP: Add support for scikit-learn 0.22 · sebp/scikit-survival@c374789 · GitHub
WIP: Add support for scikit-learn 0.22
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
1 parent 543f976 commit c374789

File tree

7 files changed: +204, -69 lines


requirements/prod.txt

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@ numpy
 osqp !=0.6.0,!=0.6.1
 pandas >=0.21,<0.26
 scipy >=1.0,!=1.3.0
-scikit-learn >=0.21.0,<0.22
+scikit-learn >=0.22.0,<0.23
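
As a quick sanity check that an environment matches the new pin (a minimal sketch, not part of this commit):

import sklearn

# this branch targets the 0.22 series only (scikit-learn >=0.22.0,<0.23)
assert sklearn.__version__.startswith("0.22"), sklearn.__version__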

sksurv/ensemble/boosting.py

Lines changed: 17 additions & 25 deletions
@@ -11,6 +11,7 @@
 # You should have received a copy of the GNU General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 import numbers
+import warnings

 import numpy

@@ -473,11 +474,8 @@ class GradientBoostingSurvivalAnalysis(BaseGradientBoosting, SurvivalAnalysisMix
         Best nodes are defined as relative reduction in impurity.
         If None then unlimited number of leaf nodes.

-    presort : bool or 'auto', optional, default: 'auto'
-        Whether to presort the data to speed up the finding of best splits in
-        fitting. Auto mode by default will use presorting on dense data and
-        default to normal sorting on sparse data. Setting presort to true on
-        sparse data will raise an error.
+    presort : deprecated, optional, default: 'deprecated'
+        This parameter is deprecated and will be removed in a future version.

     subsample : float, optional, default: 1.0
         The fraction of samples to be used for fitting the individual regression

@@ -498,6 +496,10 @@ class GradientBoostingSurvivalAnalysis(BaseGradientBoosting, SurvivalAnalysisMix
         once in a while (the more trees the lower the frequency). If greater
         than 1 then it prints progress and performance for every tree.

+    ccp_alpha : non-negative float, optional, default: 0.0.
+        Complexity parameter used for Minimal Cost-Complexity Pruning. The
+        subtree with the largest cost complexity that is smaller than
+        ``ccp_alpha`` will be chosen. By default, no pruning is performed.

     Attributes
     ----------

@@ -543,9 +545,10 @@ def __init__(self, loss="coxph", learning_rate=0.1, n_estimators=100,
                  max_depth=3, min_impurity_split=None,
                  min_impurity_decrease=0., random_state=None,
                  max_features=None, max_leaf_nodes=None,
-                 presort='auto',
+                 presort='deprecated',
                  subsample=1.0, dropout_rate=0.0,
-                 verbose=0):
+                 verbose=0,
+                 ccp_alpha=0.0):
         super().__init__(loss=loss,
                          learning_rate=learning_rate,
                          n_estimators=n_estimators,

@@ -562,7 +565,8 @@ def __init__(self, loss="coxph", learning_rate=0.1, n_estimators=100,
                          max_features=max_features,
                          max_leaf_nodes=max_leaf_nodes,
                          presort=presort,
-                         verbose=verbose)
+                         verbose=verbose,
+                         ccp_alpha=ccp_alpha)
         self.dropout_rate = dropout_rate

     def _check_params(self):

@@ -594,10 +598,11 @@ def _check_params(self):

         self.max_features_ = max_features

-        allowed_presort = ('auto', True, False)
-        if self.presort not in allowed_presort:
-            raise ValueError("'presort' should be in {}. Got {!r} instead."
-                             .format(allowed_presort, self.presort))
+        if self.presort != 'deprecated':
+            warnings.warn("The parameter 'presort' is deprecated and has no "
+                          "effect. It will be removed in v0.24. You can "
+                          "suppress this warning by not passing any value "
+                          "to the 'presort' parameter.", DeprecationWarning)

         if self.loss not in LOSS_FUNCTIONS:
             raise ValueError("Loss {!r} not supported.".format(self.loss))

@@ -835,20 +840,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
         # The rng state must be preserved if warm_start is True
         self._rng = check_random_state(self.random_state)

-        if self.presort is True and issparse(X):
-            raise ValueError(
-                "Presorting is not supported for sparse matrices.")
-
-        presort = self.presort
-        # Allow presort to be 'auto', which means True if the dataset is dense,
-        # otherwise it will be False.
-        if presort == 'auto':
-            presort = not issparse(X)
-
         X_idx_sorted = None
-        if presort:
-            X_idx_sorted = numpy.asfortranarray(numpy.argsort(X, axis=0),
-                                                dtype=numpy.int32)

         # fit the boosting stages
         y = numpy.fromiter(zip(event, time), dtype=[('event', numpy.bool), ('time', numpy.float64)])
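
For illustration, a minimal sketch of the new behaviour in this file, assuming this branch is installed; the survival data below is synthetic and made up for the example. ccp_alpha is forwarded to the underlying regression trees, and passing any explicit value for presort now only emits a DeprecationWarning:

import warnings

import numpy
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

# synthetic right-censored survival data, for illustration only
rng = numpy.random.RandomState(0)
X = rng.normal(size=(200, 5))
event = rng.binomial(1, 0.7, size=200).astype(bool)
time = rng.exponential(scale=10., size=200)
y = numpy.fromiter(zip(event, time), dtype=[('event', bool), ('time', float)])

# each base learner is pruned with minimal cost-complexity pruning
est = GradientBoostingSurvivalAnalysis(n_estimators=50, ccp_alpha=0.01,
                                       random_state=0)
est.fit(X, y)

# 'presort' is still accepted, but any explicit value triggers a warning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    GradientBoostingSurvivalAnalysis(n_estimators=10, presort=True).fit(X, y)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)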

sksurv/ensemble/forest.py

Lines changed: 28 additions & 7 deletions
@@ -2,8 +2,8 @@
 import warnings
 import numpy as np
 from sklearn.ensemble.base import _partition_estimators
-from sklearn.ensemble.forest import BaseForest, _accumulate_prediction, \
-    _generate_unsampled_indices, _parallel_build_trees
+from sklearn.ensemble._forest import BaseForest, _accumulate_prediction, \
+    _generate_unsampled_indices, _get_n_samples_bootstrap, _parallel_build_trees
 from sklearn.tree._tree import DTYPE
 from sklearn.utils._joblib import Parallel, delayed
 from sklearn.utils.fixes import _joblib_parallel_args

@@ -117,6 +117,11 @@ class RandomSurvivalForest(BaseForest, SurvivalAnalysisMixin):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.

+    ccp_alpha : non-negative float, optional, default: 0.0.
+        Complexity parameter used for Minimal Cost-Complexity Pruning. The
+        subtree with the largest cost complexity that is smaller than
+        ``ccp_alpha`` will be chosen. By default, no pruning is performed.
+
     Attributes
     ----------
     estimators_ : list of SurvivalTree instances

@@ -177,7 +182,9 @@ def __init__(self,
                  n_jobs=None,
                  random_state=None,
                  verbose=0,
-                 warm_start=False):
+                 warm_start=False,
+                 ccp_alpha=0.0,
+                 max_samples=None):
         super().__init__(
             base_estimator=SurvivalTree(),
             n_estimators=n_estimators,

@@ -187,20 +194,23 @@ def __init__(self,
                               "min_weight_fraction_leaf",
                               "max_features",
                               "max_leaf_nodes",
-                              "random_state"),
+                              "random_state",
+                              "ccp_alpha"),
             bootstrap=bootstrap,
             oob_score=oob_score,
             n_jobs=n_jobs,
             random_state=random_state,
             verbose=verbose,
-            warm_start=warm_start)
+            warm_start=warm_start,
+            max_samples=max_samples)

         self.max_depth = max_depth
         self.min_samples_split = min_samples_split
         self.min_samples_leaf = min_samples_leaf
         self.min_weight_fraction_leaf = min_weight_fraction_leaf
         self.max_features = max_features
         self.max_leaf_nodes = max_leaf_nodes
+        self.ccp_alpha = ccp_alpha

     @property
     def feature_importances_(self):

@@ -234,6 +244,12 @@ def fit(self, X, y, sample_weight=None):
         y_numeric[:, 0] = time.astype(np.float64)
         y_numeric[:, 1] = event.astype(np.float64)

+        # Get bootstrap sample size
+        n_samples_bootstrap = _get_n_samples_bootstrap(
+            n_samples=X.shape[0],
+            max_samples=self.max_samples
+        )
+
         # Check parameters
         self._validate_estimator()

@@ -277,7 +293,8 @@ def fit(self, X, y, sample_weight=None):
             **_joblib_parallel_args(prefer='threads'))(
             delayed(_parallel_build_trees)(
                 t, self, X, (y_numeric, self.event_times_), sample_weight, i, len(trees),
-                verbose=self.verbose)
+                verbose=self.verbose,
+                n_samples_bootstrap=n_samples_bootstrap)
             for i, t in enumerate(trees))

         # Collect newly grown trees

@@ -298,9 +315,13 @@ def _set_oob_score(self, X, y):
         predictions = np.zeros(n_samples)
         n_predictions = np.zeros(n_samples)

+        n_samples_bootstrap = _get_n_samples_bootstrap(
+            n_samples, self.max_samples
+        )
+
         for estimator in self.estimators_:
             unsampled_indices = _generate_unsampled_indices(
-                estimator.random_state, n_samples)
+                estimator.random_state, n_samples, n_samples_bootstrap)
             p_estimator = estimator.predict(
                 X[unsampled_indices, :], check_input=False)
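
A minimal sketch of the two new forest parameters, again assuming this branch is installed and using made-up data: max_samples limits the size of each bootstrap sample (a float is treated as a fraction of the training set by _get_n_samples_bootstrap), and ccp_alpha is passed through to every SurvivalTree:

import numpy
from sksurv.ensemble import RandomSurvivalForest

# synthetic right-censored survival data, for illustration only
rng = numpy.random.RandomState(0)
X = rng.normal(size=(200, 5))
event = rng.binomial(1, 0.7, size=200).astype(bool)
time = rng.exponential(scale=10., size=200)
y = numpy.fromiter(zip(event, time), dtype=[('event', bool), ('time', float)])

# each tree is grown on a bootstrap sample of at most half the data;
# out-of-bag scoring uses the same bootstrap size
rsf = RandomSurvivalForest(n_estimators=20, max_samples=0.5, ccp_alpha=0.01,
                           oob_score=True, random_state=0)
rsf.fit(X, y)
print(rsf.oob_score_)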

sksurv/tree/tree.py

Lines changed: 75 additions & 20 deletions
@@ -1,11 +1,15 @@
 from math import ceil
 import numbers
+import warnings
 import numpy as np
-from sklearn.base import BaseEstimator
+from sklearn.base import BaseEstimator, clone
 from sklearn.tree import _tree
 from sklearn.tree._splitter import Splitter
 from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
+from sklearn.tree._tree import _build_pruned_tree_ccp
+from sklearn.tree._tree import ccp_pruning_path
 from sklearn.tree.tree import DENSE_SPLITTERS
+from sklearn.utils import Bunch
 from sklearn.utils.validation import check_array, check_is_fitted, check_random_state

 from ..base import SurvivalAnalysisMixin

@@ -89,12 +93,13 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
         Best nodes are defined as relative reduction in impurity.
         If None then unlimited number of leaf nodes.

-    presort : bool, optional, default: False
-        Whether to presort the data to speed up the finding of best splits in
-        fitting. For the default settings of a decision tree on large
-        datasets, setting this to true may slow down the training process.
-        When using either a smaller dataset or a restricted depth, this may
-        speed up the training.
+    presort : deprecated, optional, default: 'deprecated'
+        This parameter is deprecated and will be removed in a future version.
+
+    ccp_alpha : non-negative float, optional, default: 0.0.
+        Complexity parameter used for Minimal Cost-Complexity Pruning. The
+        subtree with the largest cost complexity that is smaller than
+        ``ccp_alpha`` will be chosen. By default, no pruning is performed.

     Attributes
     ----------

@@ -132,7 +137,8 @@ def __init__(self,
                  max_features=None,
                  random_state=None,
                  max_leaf_nodes=None,
-                 presort=False):
+                 presort='deprecated',
+                 ccp_alpha=0.0):
         self.splitter = splitter
         self.max_depth = max_depth
         self.min_samples_split = min_samples_split

@@ -142,6 +148,7 @@ def __init__(self,
         self.random_state = random_state
         self.max_leaf_nodes = max_leaf_nodes
         self.presort = presort
+        self.ccp_alpha = ccp_alpha

     def fit(self, X, y, sample_weight=None, check_input=True,
             X_idx_sorted=None):

@@ -186,10 +193,6 @@ def fit(self, X, y, sample_weight=None, check_input=True,
         n_samples, self.n_features_ = X.shape
         params = self._check_params(n_samples)

-        if params["presort"]:
-            X_idx_sorted = np.asfortranarray(np.argsort(X, axis=0),
-                                             dtype=np.int32)
-
         self.n_outputs_ = self.event_times_.shape[0]
         # one "class" for CHF, one for survival function
         self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2

@@ -204,8 +207,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
                              self.max_features_,
                              params["min_samples_leaf"],
                              params["min_weight_leaf"],
-                             random_state,
-                             self.presort)
+                             random_state)

         self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_)

@@ -230,8 +232,59 @@ def fit(self, X, y, sample_weight=None, check_input=True,
         builder.build(self.tree_, X, y_numeric, sample_weight, X_idx_sorted)

+        self._prune_tree()
+
         return self

+    def _prune_tree(self):
+        """Prune tree using Minimal Cost-Complexity Pruning."""
+        check_is_fitted(self)
+
+        if self.ccp_alpha < 0.0:
+            raise ValueError("ccp_alpha must be greater than or equal to 0")
+
+        if self.ccp_alpha == 0.0:
+            return
+
+        # build pruned tree
+        n_classes = np.atleast_1d(self.n_classes_)
+        pruned_tree = Tree(self.n_features_, n_classes, self.n_outputs_)
+        _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha)
+
+        self.tree_ = pruned_tree
+
+    def cost_complexity_pruning_path(self, X, y, sample_weight=None):
+        """Compute the pruning path during Minimal Cost-Complexity Pruning.
+
+        See :ref:`minimal_cost_complexity_pruning` for details on the
+        pruning process.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The training input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csc_matrix``.
+
+        y : structured array, shape = (n_samples,)
+            A structured array containing the binary event indicator
+            as first field, and time of event or time of censoring as
+            second field.
+
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights. If None, then samples are equally weighted. Splits
+            that would create child nodes with net zero or negative weight are
+            ignored while searching for a split in each node.
+
+        Returns
+        -------
+        ccp_path : Bunch
+            Dictionary-like object, with attributes:
+
+            ccp_alphas : ndarray
+                Effective alphas of subtree during pruning.
+
+            impurities : ndarray
+                Sum of the impurities of the subtree leaves for the
+                corresponding alpha value in ``ccp_alphas``.
+        """
+        est = clone(self).set_params(ccp_alpha=0.0)
+        est.fit(X, y, sample_weight=sample_weight)
+        return Bunch(**ccp_pruning_path(est.tree_))
+
     def _check_params(self, n_samples):
         # Check parameters
         max_depth = ((2 ** 31) - 1 if self.max_depth is None

@@ -252,11 +305,14 @@ def _check_params(self, n_samples):
         min_weight_leaf = self.min_weight_fraction_leaf * n_samples
         min_impurity_split = 1e-7

-        allowed_presort = ('auto', True, False)
-        if self.presort not in allowed_presort:
-            raise ValueError("'presort' should be in {}. Got {!r} instead."
-                             .format(allowed_presort, self.presort))
-        presort = True if self.presort == 'auto' else self.presort
+        if self.presort != 'deprecated':
+            warnings.warn("The parameter 'presort' is deprecated and has no "
+                          "effect. It will be removed in v0.24. You can "
+                          "suppress this warning by not passing any value "
+                          "to the 'presort' parameter.", DeprecationWarning)
+
+        if self.ccp_alpha < 0.0:
+            raise ValueError("ccp_alpha must be greater than or equal to 0")

         return {
             "max_depth": max_depth,

@@ -265,7 +321,6 @@ def _check_params(self, n_samples):
             "min_samples_split": min_samples_split,
             "min_impurity_split": min_impurity_split,
             "min_weight_leaf": min_weight_leaf,
-            "presort": presort,
         }

     def _check_max_leaf_nodes(self):
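
The new cost_complexity_pruning_path mirrors the scikit-learn 0.22 API for decision trees. A hedged sketch of how it could be used to choose ccp_alpha, assuming this branch is installed and using the same kind of synthetic data as above:

import numpy
from sksurv.tree import SurvivalTree

# synthetic right-censored survival data, for illustration only
rng = numpy.random.RandomState(0)
X = rng.normal(size=(200, 5))
event = rng.binomial(1, 0.7, size=200).astype(bool)
time = rng.exponential(scale=10., size=200)
y = numpy.fromiter(zip(event, time), dtype=[('event', bool), ('time', float)])

# effective alphas and total leaf impurities along the pruning path
path = SurvivalTree().cost_complexity_pruning_path(X, y)

# larger alphas prune more aggressively and yield smaller trees
for alpha in path.ccp_alphas:
    tree = SurvivalTree(ccp_alpha=alpha).fit(X, y)
    print("%.4f -> %d nodes" % (alpha, tree.tree_.node_count))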
