-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG+2] Classifier chain #7602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[MRG+2] Classifier chain #7602
Changes from all commits
Commits
Show all changes
96 commits
Select commit
Hold shift + click to select a range
5b9b5d4
initial commit
bc48173
added shuffle parameter
ebe0362
fixes for python 3 support
5b7d36c
fixed formatting issues
cb12665
more formatting fixes
5b05105
removed default base_estimator
0720082
doc string formatting
073af55
added doctoring to classifier_chain example
ddf8562
added multi_label to setup.py
4a53cd4
bug fix
87afb90
updated classifier chain example
04bf31a
formatting fix
be960a7
now using numpy RandomState
a01e114
random_state fix
5e0644b
removed unnecessary import
d3cfd9f
formatting fix
5a22945
comment fix
6e38f10
minor comment fix
0806a40
formatting fix
8abebc2
pushing again
d024eda
formatting fix
89cafad
added documentation
273ac8a
Merge branch 'classifier_chain' of https://github.com/adamklec/scikit…
322b56e
requested changes
cbfc98d
added doc file
375636b
minor fixes. added support for sparse labels/
38cb4a6
updated doc strings.
a84eeb5
update doc string
92dd7c9
Merge branch 'master2' into classifier_chain
ad5369f
removed main
2a142dc
renamed multi_label to multilabel.
838a815
Merge branch 'master2' into classifier_chain
9960a71
convert range to list for python 3
b013da6
fix for passing range as order
eed45b0
remove old doc strings
323c13d
fix sparse labels snd formatting
998a7a1
sparse label fix
ce91094
label slicing fix
f7e913c
added performance test and descriptions to test_classifier_chain
2a5d6d7
requested changes
2ce7662
minor fixes
7651518
minor fixes
6a3cb61
minor fix
42900ea
merged multilabel into multiclass
a08b4ac
move ClassifierChain to multioutput
ab6d6bb
renaming
0d9412e
more renaming
97eac6b
minor fixes
e58dbd8
formatting
6d06467
added predict_proba and decision_function
17e3235
minor changes. performance test is failing.
60e9062
use binary predictions as features in decision_function and predict_p…
eadc718
formatting fix
c5bea5c
added classes_ attribute
b708add
optimization for sparse data. removed support for sparse labels.
f6b22e2
formatting fix
57568f8
sparse data fix
cc88431
requested changes
00ffd5a
merge master into branch
a28b20f
sparse fix
4e9bc10
added support for cross_val_predict to ClassifierChain
e02835d
formatting fix
ff4a055
requested changes
7f27eac
New dataset for testing ClassifierChain.
7dcf220
added a performance test on the yeast dataset
77b1ecf
changed performance test on yeast dataset to use a single chain with …
5d6db7e
removed train test split import from classifier chain performance test
184a404
requested changes
dc2171a
update doc string
27ff734
update random_state doc string
73dce2b
formatting fix
4773cb4
requested changes
4ceff8a
Merge branch 'master2' into classifier_chain
e354fc8
formatting
5b78f12
bug fixes
4eaa9ac
added a plot to the yeast example
a78bab4
formatting
a8f2e7e
formatting
62091cc
formatting
79d686b
formatting
363ab74
formatting
446c6d6
typo fixes
b37fa80
bug fix
f60aa8b
requested changes
202ebad
make example reproducible
c00af50
formatting
7ddea73
added check_array to predict methods
954afd1
doctoring fix
1c11ddf
requested changes
e984f1a
Merge branch 'master2' into classifier_chain
80b5c53
requested changes to tests
64c0941
Merge branch 'master2' into classifier_chain
4c574f8
removed unused import
8ae0bc3
formatting
77876d6
updated whats_new.rst
7eb0467
fixed docstrings
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.. _multioutput_examples: | ||
|
||
Multioutput methods | ||
---------------- | ||
|
||
Examples concerning the :mod:`sklearn.multioutput` module. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
""" | ||
============================ | ||
Classifier Chain | ||
============================ | ||
Example of using classifier chain on a multilabel dataset. | ||
|
||
For this example we will use the `yeast | ||
http://mldata.org/repository/data/viewslug/yeast/`_ dataset which | ||
contains 2417 datapoints each with 103 features and 14 possible labels. Each | ||
datapoint has at least one label. As a baseline we first train a logistic | ||
regression classifier for each of the 14 labels. To evaluate the performance | ||
of these classifiers we predict on a held-out test set and calculate the | ||
:ref:`User Guide <jaccard_similarity_score>`. | ||
|
||
Next we create 10 classifier chains. Each classifier chain contains a | ||
logistic regression model for each of the 14 labels. The models in each | ||
chain are ordered randomly. In addition to the 103 features in the dataset, | ||
each model gets the predictions of the preceding models in the chain as | ||
features (note that by default at training time each model gets the true | ||
labels as features). These additional features allow each chain to exploit | ||
correlations among the classes. The Jaccard similarity score for each chain | ||
tends to be greater than that of the set independent logistic models. | ||
|
||
Because the models in each chain are arranged randomly there is significant | ||
variation in performance among the chains. Presumably there is an optimal | ||
ordering of the classes in a chain that will yield the best performance. | ||
However we do not know that ordering a priori. Instead we can construct an | ||
voting ensemble of classifier chains by averaging the binary predictions of | ||
the chains and apply a threshold of 0.5. The Jaccard similarity score of the | ||
ensemble is greater than that of the independent models and tends to exceed | ||
the score of each chain in the ensemble (although this is not guaranteed | ||
with randomly ordered chains). | ||
""" | ||
|
||
print(__doc__) | ||
|
||
# Author: Adam Kleczewski | ||
# License: BSD 3 clause | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from sklearn.multioutput import ClassifierChain | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.multiclass import OneVsRestClassifier | ||
from sklearn.metrics import jaccard_similarity_score | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.datasets import fetch_mldata | ||
|
||
# Load a multi-label dataset | ||
yeast = fetch_mldata('yeast') | ||
X = yeast['data'] | ||
Y = yeast['target'].transpose().toarray() | ||
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2, | ||
random_state=0) | ||
|
||
# Fit an independent logistic regression model for each class using the | ||
# OneVsRestClassifier wrapper. | ||
ovr = OneVsRestClassifier(LogisticRegression()) | ||
ovr.fit(X_train, Y_train) | ||
Y_pred_ovr = ovr.predict(X_test) | ||
ovr_jaccard_score = jaccard_similarity_score(Y_test, Y_pred_ovr) | ||
|
||
# Fit an ensemble of logistic regression classifier chains and take the | ||
# take the average prediction of all the chains. | ||
chains = [ClassifierChain(LogisticRegression(), order='random', random_state=i) | ||
for i in range(10)] | ||
for chain in chains: | ||
chain.fit(X_train, Y_train) | ||
|
||
Y_pred_chains = np.array([chain.predict(X_test) for chain in | ||
chains]) | ||
chain_jaccard_scores = [jaccard_similarity_score(Y_test, Y_pred_chain >= .5) | ||
for Y_pred_chain in Y_pred_chains] | ||
|
||
Y_pred_ensemble = Y_pred_chains.mean(axis=0) | ||
ensemble_jaccard_score = jaccard_similarity_score(Y_test, | ||
Y_pred_ensemble >= .5) | ||
|
||
model_scores = [ovr_jaccard_score] + chain_jaccard_scores | ||
model_scores.append(ensemble_jaccard_score) | ||
|
||
model_names = ('Independent Models', | ||
'Chain 1', | ||
'Chain 2', | ||
'Chain 3', | ||
'Chain 4', | ||
'Chain 5', | ||
'Chain 6', | ||
'Chain 7', | ||
'Chain 8', | ||
'Chain 9', | ||
'Chain 10', | ||
'Ensemble Average') | ||
|
||
y_pos = np.arange(len(model_names)) | ||
y_pos[1:] += 1 | ||
y_pos[-1] += 1 | ||
|
||
# Plot the Jaccard similarity scores for the independent model, each of the | ||
# chains, and the ensemble (note that the vertical axis on this plot does | ||
# not begin at 0). | ||
|
||
fig = plt.figure(figsize=(7, 4)) | ||
plt.title('Classifier Chain Ensemble') | ||
plt.xticks(y_pos, model_names, rotation='vertical') | ||
plt.ylabel('Jaccard Similarity Score') | ||
plt.ylim([min(model_scores) * .9, max(model_scores) * 1.1]) | ||
colors = ['r'] + ['b'] * len(chain_jaccard_scores) + ['g'] | ||
plt.bar(y_pos, model_scores, align='center', alpha=0.5, color=colors) | ||
plt.show() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You may want to elaborate on the connections between classifier chains and model stacking. Might it be worth it to also mention that regression chains are possible? Maybe that's better saved for when someone implements it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess my preference is to keep the text concise and focused on the example. If there is a specific point you think we should make here I'd consider adding a sentence or two. Otherwise I'm inclined to leave it as is.
I think we should stay away from dealing with regression models in this PR. But I agree that is a good direction to go.