from math import ceil
import numbers
+import warnings
import numpy as np
-from sklearn.base import BaseEstimator
+from sklearn.base import BaseEstimator, clone
from sklearn.tree import _tree
from sklearn.tree._splitter import Splitter
from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
+from sklearn.tree._tree import _build_pruned_tree_ccp
+from sklearn.tree._tree import ccp_pruning_path
from sklearn.tree.tree import DENSE_SPLITTERS
+from sklearn.utils import Bunch
from sklearn.utils.validation import check_array, check_is_fitted, check_random_state

from ..base import SurvivalAnalysisMixin
@@ -89,12 +93,13 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
        Best nodes are defined as relative reduction in impurity.
        If None then unlimited number of leaf nodes.

-    presort : bool, optional, default: False
-        Whether to presort the data to speed up the finding of best splits in
-        fitting. For the default settings of a decision tree on large
-        datasets, setting this to true may slow down the training process.
-        When using either a smaller dataset or a restricted depth, this may
-        speed up the training.
+    presort : deprecated, optional, default: 'deprecated'
+        This parameter is deprecated and will be removed in a future version.
+
+    ccp_alpha : non-negative float, optional, default: 0.0
+        Complexity parameter used for Minimal Cost-Complexity Pruning. The
+        subtree with the largest cost complexity that is smaller than
+        ``ccp_alpha`` will be chosen. By default, no pruning is performed.

    Attributes
    ----------
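The ``ccp_alpha`` parameter documented in the hunk above enables Minimal Cost-Complexity Pruning after the tree has been built. The following is an illustrative sketch only, not part of this diff: it assumes scikit-survival's ``sksurv.tree.SurvivalTree`` import path and a structured target array whose first field is the event indicator and whose second field is the observed time; the synthetic data and the value ``0.01`` are arbitrary.

    # Illustrative sketch only: fit a SurvivalTree with and without pruning.
    # Assumes sksurv.tree.SurvivalTree and a structured survival target
    # (event indicator first, observed time second); values are arbitrary.
    import numpy as np
    from sksurv.tree import SurvivalTree

    rng = np.random.RandomState(0)
    X = rng.normal(size=(200, 5))
    y = np.empty(200, dtype=[("event", bool), ("time", float)])
    y["event"] = rng.binomial(1, 0.7, size=200).astype(bool)
    y["time"] = rng.exponential(scale=10.0, size=200)

    unpruned = SurvivalTree().fit(X, y)
    pruned = SurvivalTree(ccp_alpha=0.01).fit(X, y)  # ccp_alpha > 0 prunes the fitted tree

    # Pruning can only remove nodes, never add them.
    print(unpruned.tree_.node_count, ">=", pruned.tree_.node_count)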
@@ -132,7 +137,8 @@ def __init__(self,
                 max_features=None,
                 random_state=None,
                 max_leaf_nodes=None,
-                 presort=False):
+                 presort='deprecated',
+                 ccp_alpha=0.0):
        self.splitter = splitter
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
@@ -142,6 +148,7 @@ def __init__(self,
        self.random_state = random_state
        self.max_leaf_nodes = max_leaf_nodes
        self.presort = presort
+        self.ccp_alpha = ccp_alpha

    def fit(self, X, y, sample_weight=None, check_input=True,
            X_idx_sorted=None):
@@ -186,10 +193,6 @@ def fit(self, X, y, sample_weight=None, check_input=True,
        n_samples, self.n_features_ = X.shape
        params = self._check_params(n_samples)

-        if params["presort"]:
-            X_idx_sorted = np.asfortranarray(np.argsort(X, axis=0),
-                                             dtype=np.int32)
-
        self.n_outputs_ = self.event_times_.shape[0]
        # one "class" for CHF, one for survival function
        self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2
@@ -204,8 +207,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
                self.max_features_,
                params["min_samples_leaf"],
                params["min_weight_leaf"],
-                random_state,
-                self.presort)
+                random_state)

        self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_)

@@ -230,8 +232,59 @@ def fit(self, X, y, sample_weight=None, check_input=True,

        builder.build(self.tree_, X, y_numeric, sample_weight, X_idx_sorted)

+        self._prune_tree()
+
        return self

+    def _prune_tree(self):
+        """Prune tree using Minimal Cost-Complexity Pruning."""
+        check_is_fitted(self)
+
+        if self.ccp_alpha < 0.0:
+            raise ValueError("ccp_alpha must be greater than or equal to 0")
+
+        if self.ccp_alpha == 0.0:
+            return
+
+        # build pruned tree
+        n_classes = np.atleast_1d(self.n_classes_)
+        pruned_tree = Tree(self.n_features_, n_classes, self.n_outputs_)
+        _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha)
+
+        self.tree_ = pruned_tree
+
+    def cost_complexity_pruning_path(self, X, y, sample_weight=None):
+        """Compute the pruning path during Minimal Cost-Complexity Pruning.
+        See :ref:`minimal_cost_complexity_pruning` for details on the pruning
+        process.
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The training input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csc_matrix``.
+        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
+            The target values (class labels) as integers or strings.
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights. If None, then samples are equally weighted. Splits
+            that would create child nodes with net zero or negative weight are
+            ignored while searching for a split in each node. Splits are also
+            ignored if they would result in any single class carrying a
+            negative weight in either child node.
+        Returns
+        -------
+        ccp_path : Bunch
+            Dictionary-like object, with attributes:
+            ccp_alphas : ndarray
+                Effective alphas of subtree during pruning.
+            impurities : ndarray
+                Sum of the impurities of the subtree leaves for the
+                corresponding alpha value in ``ccp_alphas``.
+        """
+        est = clone(self).set_params(ccp_alpha=0.0)
+        est.fit(X, y, sample_weight=sample_weight)
+        return Bunch(**ccp_pruning_path(est.tree_))
+
    def _check_params(self, n_samples):
        # Check parameters
        max_depth = ((2 ** 31) - 1 if self.max_depth is None
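The new ``cost_complexity_pruning_path`` method added in the hunk above clones the estimator with ``ccp_alpha=0.0``, fits it, and returns the effective alphas and total leaf impurities along the pruning path. A hypothetical sketch of how it could be used to choose ``ccp_alpha`` follows; it reuses the ``X``, ``y`` arrays and the ``SurvivalTree`` import from the earlier sketch, and selecting the middle alpha is an arbitrary choice for illustration.

    # Illustrative sketch only: inspect the pruning path and refit with one
    # of the effective alphas (X, y and SurvivalTree as in the sketch above).
    path = SurvivalTree().cost_complexity_pruning_path(X, y)

    # Each effective alpha collapses one more subtree; impurities are the
    # summed leaf impurities of the corresponding pruned tree.
    for alpha, impurity in zip(path.ccp_alphas, path.impurities):
        print(f"alpha={alpha:.4f}  total leaf impurity={impurity:.4f}")

    # Refit with one of the reported alphas to obtain that subtree.
    mid_alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]
    tree = SurvivalTree(ccp_alpha=mid_alpha).fit(X, y)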
@@ -252,11 +305,14 @@ def _check_params(self, n_samples):
        min_weight_leaf = self.min_weight_fraction_leaf * n_samples
        min_impurity_split = 1e-7

-        allowed_presort = ('auto', True, False)
-        if self.presort not in allowed_presort:
-            raise ValueError("'presort' should be in {}. Got {!r} instead."
-                             .format(allowed_presort, self.presort))
-        presort = True if self.presort == 'auto' else self.presort
+        if self.presort != 'deprecated':
+            warnings.warn("The parameter 'presort' is deprecated and has no "
+                          "effect. It will be removed in v0.24. You can "
+                          "suppress this warning by not passing any value "
+                          "to the 'presort' parameter.", DeprecationWarning)
+
+        if self.ccp_alpha < 0.0:
+            raise ValueError("ccp_alpha must be greater than or equal to 0")

        return {
            "max_depth": max_depth,
@@ -265,7 +321,6 @@ def _check_params(self, n_samples):
            "min_samples_split": min_samples_split,
            "min_impurity_split": min_impurity_split,
            "min_weight_leaf": min_weight_leaf,
-            "presort": presort,
        }

    def _check_max_leaf_nodes():
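With the ``_check_params`` change above, any explicit value passed for ``presort`` now only emits a ``DeprecationWarning`` during ``fit`` and otherwise has no effect. A hypothetical check, reusing the data and the ``SurvivalTree`` import from the first sketch:

    # Illustrative sketch only: passing the deprecated 'presort' parameter
    # triggers a DeprecationWarning but fitting still succeeds.
    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        SurvivalTree(presort=True).fit(X, y)

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)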