Merge pull request #1668 from glouppe/adaboost-tree · seckcoder/scikit-learn@b126421 · GitHub
[go: up one dir, main page]

Skip to content

Commit b126421

Browse files
committed
Merge pull request scikit-learn#1668 from glouppe/adaboost-tree
[MRG] Precompute X_argsorted in AdaBoost
2 parents 135636e + 759e6fa commit b126421

File tree

1 file changed

+60
-15
lines changed

1 file changed

+60
-15
lines changed

sklearn/ensemble/weight_boosting.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .base import BaseEnsemble
2828
from ..base import ClassifierMixin, RegressorMixin
2929
from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
30+
from ..tree.tree import BaseDecisionTree
3031
from ..utils import check_arrays, check_random_state
3132
from ..metrics import accuracy_score, r2_score
3233

@@ -108,12 +109,20 @@ def fit(self, X, y, sample_weight=None):
108109
self.estimator_weights_ = np.zeros(self.n_estimators, dtype=np.float)
109110
self.estimator_errors_ = np.ones(self.n_estimators, dtype=np.float)
110111

112+
# Create argsorted X for fast tree induction
113+
X_argsorted = None
114+
115+
if isinstance(self.base_estimator, BaseDecisionTree):
116+
X_argsorted = np.asfortranarray(
117+
np.argsort(X.T, axis=1).astype(np.int32).T)
118+
111119
for iboost in xrange(self.n_estimators):
112120
# Boosting step
113121
sample_weight, estimator_weight, estimator_error = self._boost(
114122
iboost,
115123
X, y,
116-
sample_weight)
124+
sample_weight,
125+
X_argsorted=X_argsorted)
117126

118127
# Early termination
119128
if sample_weight is None:
@@ -139,7 +148,7 @@ def fit(self, X, y, sample_weight=None):
139148
return self
140149

141150
@abstractmethod
142-
def _boost(self, iboost, X, y, sample_weight):
151+
def _boost(self, iboost, X, y, sample_weight, X_argsorted=None):
143152
"""Implement a single boost.
144153
145154
Warning: This method needs to be overriden by subclasses.
@@ -158,6 +167,14 @@ def _boost(self, iboost, X, y, sample_weight):
158167
sample_weight : array-like of shape = [n_samples]
159168
The current sample weights.
160169
170+
X_argsorted : array-like, shape = [n_samples, n_features] (optional)
171+
Each column of ``X_argsorted`` holds the row indices of ``X``
172+
sorted according to the value of the corresponding feature
173+
in ascending order.
174+
The argument is supported to enable multiple decision trees
175+
to share the data structure and to avoid re-computation in
176+
tree ensembles. For maximum efficiency use dtype np.int32.
177+
161178
Returns
162179
-------
163180
sample_weight : array-like of shape = [n_samples] or None
@@ -367,7 +384,7 @@ def fit(self, X, y, sample_weight=None):
367384

368385
return super(AdaBoostClassifier, self).fit(X, y, sample_weight)
369386

370-
def _boost(self, iboost, X, y, sample_weight):
387+
def _boost(self, iboost, X, y, sample_weight, X_argsorted=None):
371388
"""Implement a single boost.
372389
373390
Perform a single boost according to the real multi-class SAMME.R
@@ -388,6 +405,14 @@ def _boost(self, iboost, X, y, sample_weight):
388405
sample_weight : array-like of shape = [n_samples]
389406
The current sample weights.
390407
408+
X_argsorted : array-like, shape = [n_samples, n_features] (optional)
409+
Each column of ``X_argsorted`` holds the row indices of ``X``
410+
sorted according to the value of the corresponding feature
411+
in ascending order.
412+
The argument is supported to enable multiple decision trees
413+
to share the data structure and to avoid re-computation in
414+
tree ensembles. For maximum efficiency use dtype np.int32.
415+
391416
Returns
392417
-------
393418
sample_weight : array-like of shape = [n_samples] or None
@@ -403,17 +428,24 @@ def _boost(self, iboost, X, y, sample_weight):
403428
If None then boosting has terminated early.
404429
"""
405430
if self.algorithm == 'SAMME.R':
406-
return self._boost_real(iboost, X, y, sample_weight)
431+
return self._boost_real(iboost, X, y, sample_weight,
432+
X_argsorted=X_argsorted)
407433

408434
else: # elif self.algorithm == "SAMME":
409-
return self._boost_discrete(iboost, X, y, sample_weight)
435+
return self._boost_discrete(iboost, X, y, sample_weight,
436+
X_argsorted=X_argsorted)
410437

411-
def _boost_real(self, iboost, X, y, sample_weight):
438+
def _boost_real(self, iboost, X, y, sample_weight, X_argsorted=None):
412439
"""Implement a single boost using the SAMME.R real algorithm."""
413440
estimator = self._make_estimator()
414441

415-
y_predict_proba = estimator.fit(
416-
X, y, sample_weight=sample_weight).predict_proba(X)
442+
if X_argsorted is not None:
443+
estimator.fit(X, y, sample_weight=sample_weight,
444+
X_argsorted=X_argsorted)
445+
else:
446+
estimator.fit(X, y, sample_weight=sample_weight)
447+
448+
y_predict_proba = estimator.predict_proba(X)
417449

418450
if iboost == 0:
419451
self.classes_ = getattr(estimator, 'classes_', None)
@@ -464,12 +496,17 @@ def _boost_real(self, iboost, X, y, sample_weight):
464496

465497
return sample_weight, 1., estimator_error
466498

467-
def _boost_discrete(self, iboost, X, y, sample_weight):
499+
def _boost_discrete(self, iboost, X, y, sample_weight, X_argsorted=None):
468500
"""Implement a single boost using the SAMME discrete algorithm."""
469501
estimator = self._make_estimator()
470502

471-
y_predict = estimator.fit(
472-
X, y, sample_weight=sample_weight).predict(X)
503+
if X_argsorted is not None:
504+
estimator.fit(X, y, sample_weight=sample_weight,
505+
X_argsorted=X_argsorted)
506+
else:
507+
estimator.fit(X, y, sample_weight=sample_weight)
508+
509+
y_predict = estimator.predict(X)
473510

474511
if iboost == 0:
475512
self.classes_ = getattr(estimator, 'classes_', None)
@@ -875,7 +912,7 @@ def fit(self, X, y, sample_weight=None):
875912
# Fit
876913
return super(AdaBoostRegressor, self).fit(X, y, sample_weight)
877914

878-
def _boost(self, iboost, X, y, sample_weight):
915+
def _boost(self, iboost, X, y, sample_weight, X_argsorted=None):
879916
"""Implement a single boost for regression
880917
881918
Perform a single boost according to the AdaBoost.R2 algorithm and
@@ -896,6 +933,14 @@ def _boost(self, iboost, X, y, sample_weight):
896933
sample_weight : array-like of shape = [n_samples]
897934
The current sample weights.
898935
936+
X_argsorted : array-like, shape = [n_samples, n_features] (optional)
937+
Each column of ``X_argsorted`` holds the row indices of ``X``
938+
sorted according to the value of the corresponding feature
939+
in ascending order.
940+
The argument is supported to enable multiple decision trees
941+
to share the data structure and to avoid re-computation in
942+
tree ensembles. For maximum efficiency use dtype np.int32.
943+
899944
Returns
900945
-------
901946
sample_weight : array-like of shape = [n_samples] or None
@@ -925,8 +970,9 @@ def _boost(self, iboost, X, y, sample_weight):
925970

926971
# Fit on the bootstrapped sample and obtain a prediction
927972
# for all samples in the training set
928-
y_predict = estimator.fit(
929-
X[bootstrap_idx], y[bootstrap_idx]).predict(X)
973+
# X_argsorted is not used since bootstrap copies are used.
974+
estimator.fit(X[bootstrap_idx], y[bootstrap_idx])
975+
y_predict = estimator.predict(X)
930976

931977
error_vect = np.abs(y_predict - y)
932978
error_max = error_vect.max()
@@ -965,7 +1011,6 @@ def _boost(self, iboost, X, y, sample_weight):
9651011
return sample_weight, estimator_weight, estimator_error
9661012

9671013
def _get_median_predict(self, X, limit=-1):
968-
9691014
if not self.estimators_:
9701015
raise RuntimeError(
9711016
("{0} is not initialized. "

0 commit comments

Comments (0)