[MRG] Use _check_sample_weight in BaseForest by ritalulu · Pull Request #15492 · scikit-learn/scikit-learn · GitHub

[MRG] Use _check_sample_weight in BaseForest #15492


Merged

Changes from all commits
35 changes: 19 additions & 16 deletions sklearn/ensemble/_forest.py
@@ -61,7 +61,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
from ._base import BaseEnsemble, _partition_estimators
from ..utils.fixes import _joblib_parallel_args
from ..utils.multiclass import check_classification_targets
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, _check_sample_weight

__all__ = ["RandomForestClassifier",
@@ -249,8 +249,7 @@ def decision_path(self, X):
X = self._validate_X_predict(X)
indicators = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
**_joblib_parallel_args(prefer='threads'))(
-delayed(tree.decision_path)(X,
-                            check_input=False)
+delayed(tree.decision_path)(X, check_input=False)
Member:

Thanks for fixing all these cosmetic issues, but it adds some noise both to the review process and to the history of this file, and it may lead to unnecessary conflicts with other pull-requests.
Therefore, we usually prefer changing only the strict minimum.

Member:

ok to merge or do you want to undo these?

Member:

No strong feelings, just wanted to mention it to avoid it in the future.

Member:

that's what I figured, just wanted to confirm :)

for tree in self.estimators_)

n_nodes = [0]
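For context on the hunk above: `decision_path` concatenates the per-tree node-indicator matrices, and `n_nodes` is used to build the pointer that delimits each tree's columns. A minimal usage sketch, assuming a toy dataset (the `make_classification` call and `n_estimators` value are illustrative, not from the PR):

```python
# Sketch of calling decision_path on a fitted forest.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=100, random_state=0)
forest = RandomForestClassifier(n_estimators=5, random_state=0).fit(X, y)

# indicator: sparse (n_samples, total_n_nodes) matrix; columns
# n_nodes_ptr[i]:n_nodes_ptr[i + 1] belong to the i-th tree.
indicator, n_nodes_ptr = forest.decision_path(X)
print(indicator.shape, n_nodes_ptr[:3])
```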
@@ -288,7 +287,7 @@ def fit(self, X, y, sample_weight=None):
X = check_array(X, accept_sparse="csc", dtype=DTYPE)
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
if sample_weight is not None:
-sample_weight = check_array(sample_weight, ensure_2d=False)
+sample_weight = _check_sample_weight(sample_weight, X)
if issparse(X):
# Pre-sort indices to avoid that each individual tree of the
# ensemble sorts the indices.
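This one-line switch is the substance of the PR: `_check_sample_weight` validates `sample_weight` against `X` instead of merely coercing it to an array. A simplified sketch of the behavior it adds (the function name below is hypothetical; the real helper lives in `sklearn.utils.validation` and also handles dtype coercion):

```python
# Approximation of _check_sample_weight, not the library code itself.
import numbers
import numpy as np

def check_sample_weight_sketch(sample_weight, X, dtype=np.float64):
    n_samples = X.shape[0]
    if sample_weight is None:
        # None becomes a uniform weight vector of the right length.
        return np.ones(n_samples, dtype=dtype)
    if isinstance(sample_weight, numbers.Number):
        # A scalar is broadcast to all samples.
        return np.full(n_samples, sample_weight, dtype=dtype)
    sample_weight = np.asarray(sample_weight, dtype=dtype)
    if sample_weight.ndim != 1 or sample_weight.shape[0] != n_samples:
        # check_array(..., ensure_2d=False) alone never compared the
        # length against X; this is the extra validation the PR buys.
        raise ValueError("sample_weight.shape == %s, expected (%d,)"
                         % (sample_weight.shape, n_samples))
    return sample_weight
```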
@@ -538,7 +537,8 @@ def _validate_y_class_weight(self, y):

y_store_unique_indices = np.zeros(y.shape, dtype=np.int)
for k in range(self.n_outputs_):
-classes_k, y_store_unique_indices[:, k] = np.unique(y[:, k], return_inverse=True)
+classes_k, y_store_unique_indices[:, k] = \
+    np.unique(y[:, k], return_inverse=True)
self.classes_.append(classes_k)
self.n_classes_.append(classes_k.shape[0])
y = y_store_unique_indices
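The rewrapped line above encodes each output column of `y` as integer indices. What `np.unique(..., return_inverse=True)` returns, shown on a toy column (the array contents are illustrative):

```python
import numpy as np

y_k = np.array(['b', 'a', 'b', 'c'])
classes_k, inverse = np.unique(y_k, return_inverse=True)
print(classes_k)  # ['a' 'b' 'c']  -- sorted unique class labels
print(inverse)    # [1 0 1 2]     -- each sample's index into classes_k
```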
@@ -548,16 +548,18 @@ def _validate_y_class_weight(self, y):
if isinstance(self.class_weight, str):
if self.class_weight not in valid_presets:
raise ValueError('Valid presets for class_weight include '
'"balanced" and "balanced_subsample". Given "%s".'
'"balanced" and "balanced_subsample".'
'Given "%s".'
% self.class_weight)
if self.warm_start:
-warn('class_weight presets "balanced" or "balanced_subsample" are '
+warn('class_weight presets "balanced" or '
+     '"balanced_subsample" are '
'not recommended for warm_start if the fitted data '
'differs from the full dataset. In order to use '
'"balanced" weights, use compute_class_weight("balanced", '
'classes, y). In place of y you can use a large '
'enough sample of the full training set target to '
'properly estimate the class frequency '
'"balanced" weights, use compute_class_weight '
'("balanced", classes, y). In place of y you can use '
'a large enough sample of the full training set '
'target to properly estimate the class frequency '
'distributions. Pass the resulting weights as the '
'class_weight parameter.')
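The workaround this warning recommends, sketched on a toy imbalanced target (the array contents are assumptions for illustration):

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0, 0, 0, 1])
classes = np.unique(y)
# "balanced": n_samples / (n_classes * np.bincount(y)) = 4 / (2 * [3, 1])
weights = compute_class_weight('balanced', classes=classes, y=y)
print(dict(zip(classes, weights)))  # {0: 0.666..., 1: 2.0}
# Pass the result as the class_weight parameter, e.g.
# RandomForestClassifier(class_weight=dict(zip(classes, weights))).
```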

@@ -615,9 +617,9 @@ def predict_proba(self, X):
"""Predict class probabilities for X.

The predicted class probabilities of an input sample are computed as
-the mean predicted class probabilities of the trees in the forest. The
-class probability of a single tree is the fraction of samples of the same
-class in a leaf.
+the mean predicted class probabilities of the trees in the forest.
+The class probability of a single tree is the fraction of samples of
+the same class in a leaf.

Parameters
----------
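The averaging described in the rewrapped docstring is easy to check directly; a sketch, where the dataset and forest size are assumptions:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(random_state=0)
forest = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# Mean of the per-tree leaf class fractions matches the forest output.
per_tree = np.stack([t.predict_proba(X) for t in forest.estimators_])
np.testing.assert_allclose(per_tree.mean(axis=0), forest.predict_proba(X))
```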
@@ -1559,8 +1561,9 @@ class ExtraTreesClassifier(ForestClassifier):
weights inversely proportional to class frequencies in the input data
as ``n_samples / (n_classes * np.bincount(y))``

The "balanced_subsample" mode is the same as "balanced" except that weights are
computed based on the bootstrap sample for every tree grown.
The "balanced_subsample" mode is the same as "balanced" except that
weights are computed based on the bootstrap sample for every tree
grown.

For multi-output, the weights of each column of y will be multiplied.
