[MRG+2] Classifier chain by adamklec · Pull Request #7602 · scikit-learn/scikit-learn · GitHub
[MRG+2] Classifier chain #7602


Merged
merged 96 commits on Jun 29, 2017
Commits
5b9b5d4
initial commit
Sep 27, 2016
bc48173
added shuffle parameter
Oct 7, 2016
ebe0362
fixes for python 3 support
Oct 7, 2016
5b7d36c
fixed formatting issues
Oct 7, 2016
cb12665
more formatting fixes
Oct 7, 2016
5b05105
removed default base_estimator
Oct 10, 2016
0720082
doc string formatting
Oct 10, 2016
073af55
added doctoring to classifier_chain example
Oct 10, 2016
ddf8562
added multi_label to setup.py
Oct 11, 2016
4a53cd4
bug fix
Oct 12, 2016
87afb90
updated classifier chain example
Oct 12, 2016
04bf31a
formatting fix
Oct 12, 2016
be960a7
now using numpy RandomState
Oct 17, 2016
a01e114
random_state fix
Oct 18, 2016
5e0644b
removed unnecessary import
Oct 18, 2016
d3cfd9f
formatting fix
Oct 18, 2016
5a22945
comment fix
Oct 19, 2016
6e38f10
minor comment fix
Oct 24, 2016
0806a40
formatting fix
Oct 25, 2016
8abebc2
pushing again
Nov 11, 2016
d024eda
formatting fix
Nov 11, 2016
89cafad
added documentation
Nov 27, 2016
273ac8a
Merge branch 'classifier_chain' of https://github.com/adamklec/scikit…
Nov 27, 2016
322b56e
requested changes
Dec 27, 2016
cbfc98d
added doc file
Dec 27, 2016
375636b
minor fixes. added support for sparse labels/
Dec 28, 2016
38cb4a6
updated doc strings.
Jan 5, 2017
a84eeb5
update doc string
Jan 6, 2017
92dd7c9
Merge branch 'master2' into classifier_chain
Jan 6, 2017
ad5369f
removed main
Jan 6, 2017
2a142dc
renamed multi_label to multilabel.
Jan 9, 2017
838a815
Merge branch 'master2' into classifier_chain
Jan 9, 2017
9960a71
convert range to list for python 3
Jan 9, 2017
b013da6
fix for passing range as order
Jan 9, 2017
eed45b0
remove old doc strings
Jan 9, 2017
323c13d
fix sparse labels snd formatting
Jan 9, 2017
998a7a1
sparse label fix
Jan 10, 2017
ce91094
label slicing fix
Jan 10, 2017
f7e913c
added performance test and descriptions to test_classifier_chain
Jan 10, 2017
2a5d6d7
requested changes
Jan 16, 2017
2ce7662
minor fixes
Jan 16, 2017
7651518
minor fixes
Jan 16, 2017
6a3cb61
minor fix
Jan 16, 2017
42900ea
merged multilabel into multiclass
Jan 16, 2017
a08b4ac
move ClassifierChain to multioutput
Jan 16, 2017
ab6d6bb
renaming
Jan 16, 2017
0d9412e
more renaming
Jan 16, 2017
97eac6b
minor fixes
Jan 16, 2017
e58dbd8
formatting
Jan 16, 2017
6d06467
added predict_proba and decision_function
Jan 17, 2017
17e3235
minor changes. performance test is failing.
Jan 17, 2017
60e9062
use binary predictions as features in decision_function and predict_p…
Jan 18, 2017
eadc718
formatting fix
Jan 18, 2017
c5bea5c
added classes_ attribute
Jan 18, 2017
b708add
optimization for sparse data. removed support for sparse labels.
Jan 24, 2017
f6b22e2
formatting fix
Jan 24, 2017
57568f8
sparse data fix
Jan 24, 2017
cc88431
requested changes
Feb 27, 2017
00ffd5a
merge master into branch
Feb 27, 2017
a28b20f
sparse fix
Feb 27, 2017
4e9bc10
added support for cross_val_predict to ClassifierChain
Mar 6, 2017
e02835d
formatting fix
Mar 6, 2017
ff4a055
requested changes
Mar 7, 2017
7f27eac
New dataset for testing ClassifierChain.
Mar 9, 2017
7dcf220
added a performance test on the yeast dataset
Mar 9, 2017
77b1ecf
changed performance test on yeast dataset to use a single chain with …
Mar 10, 2017
5d6db7e
removed train test split import from classifier chain performance test
Mar 10, 2017
184a404
requested changes
Mar 10, 2017
dc2171a
update doc string
Mar 10, 2017
27ff734
update random_state doc string
Mar 10, 2017
73dce2b
formatting fix
Mar 10, 2017
4773cb4
requested changes
May 1, 2017
4ceff8a
Merge branch 'master2' into classifier_chain
May 1, 2017
e354fc8
formatting
May 1, 2017
5b78f12
bug fixes
May 1, 2017
4eaa9ac
added a plot to the yeast example
May 1, 2017
a78bab4
formatting
May 1, 2017
a8f2e7e
formatting
May 1, 2017
62091cc
formatting
May 1, 2017
79d686b
formatting
May 2, 2017
363ab74
formatting
May 2, 2017
446c6d6
typo fixes
May 2, 2017
b37fa80
bug fix
May 2, 2017
f60aa8b
requested changes
May 30, 2017
202ebad
make example reproducible
May 30, 2017
c00af50
formatting
May 30, 2017
7ddea73
added check_array to predict methods
May 31, 2017
954afd1
doctoring fix
Jun 5, 2017
1c11ddf
requested changes
Jun 19, 2017
e984f1a
Merge branch 'master2' into classifier_chain
Jun 19, 2017
80b5c53
requested changes to tests
Jun 20, 2017
64c0941
Merge branch 'master2' into classifier_chain
Jun 20, 2017
4c574f8
removed unused import
Jun 20, 2017
8ae0bc3
formatting
Jun 20, 2017
77876d6
updated whats_new.rst
Jun 20, 2017
7eb0467
fixed docstrings
Jun 28, 2017
27 changes: 27 additions & 0 deletions doc/modules/multiclass.rst
@@ -348,3 +348,30 @@ Below is an example of multioutput classification:
[0, 0, 2],
[2, 0, 0]])

Classifier Chain
================

Classifier chains (see :class:`ClassifierChain`) are a way of combining a
number of binary classifiers into a single multi-label model that is capable
of exploiting correlations among targets.

For a multi-label classification problem with N classes, N binary
classifiers are assigned an integer between 0 and N-1. These integers
define the order of models in the chain. Each classifier is then fit on the
available training data plus the true labels of the classes whose
models were assigned a lower number.

When predicting, the true labels will not be available. Instead the
predictions of each model are passed on to the subsequent models in the
chain to be used as features.

Clearly the order of the chain is important. The first model in the chain
has no information about the other labels, while the last model in the chain
has features indicating the presence of all of the other labels. In general
one does not know the optimal ordering of the models in the chain, so
typically many randomly ordered chains are fit and their predictions are
averaged together.

.. topic:: References:
Jesse Read, Bernhard Pfahringer, Geoff Holmes, Eibe Frank,
"Classifier Chains for Multi-label Classification", 2009.
3 changes: 3 additions & 0 deletions doc/whats_new.rst
@@ -31,6 +31,9 @@ Changelog
New features
............

- Added :class:`multioutput.ClassifierChain` for multi-label
classification. By `Adam Kleczewski <adamklec>`_.

- Validation that input data contains no NaN or inf can now be suppressed
using :func:`config_context`, at your own risk. This will save on runtime,
and may be particularly useful for prediction time. :issue:`7548` by
6 changes: 6 additions & 0 deletions examples/multioutput/README.txt
@@ -0,0 +1,6 @@
.. _multioutput_examples:

Multioutput methods
-------------------

Examples concerning the :mod:`sklearn.multioutput` module.
110 changes: 110 additions & 0 deletions examples/multioutput/plot_classifier_chain_yeast.py
@@ -0,0 +1,110 @@
"""
============================
Classifier Chain
============================
Example of using a classifier chain on a multilabel dataset.

Review comment (Member):

You may want to elaborate on the connections between classifier chains and model stacking. Might it be worth it to also mention that regression chains are possible? Maybe that's better saved for when someone implements it.

Reply (Author):

I guess my preference is to keep the text concise and focused on the example. If there is a specific point you think we should make here I'd consider adding a sentence or two. Otherwise I'm inclined to leave it as is.

I think we should stay away from dealing with regression models in this PR. But I agree that is a good direction to go.

For this example we will use the `yeast
<http://mldata.org/repository/data/viewslug/yeast/>`_ dataset which
contains 2417 datapoints, each with 103 features and 14 possible labels. Each
datapoint has at least one label. As a baseline we first train a logistic
regression classifier for each of the 14 labels. To evaluate the performance
of these classifiers we predict on a held-out test set and calculate the
:ref:`Jaccard similarity score <jaccard_similarity_score>` of each model.

Next we create 10 classifier chains. Each classifier chain contains a
logistic regression model for each of the 14 labels. The models in each
chain are ordered randomly. In addition to the 103 features in the dataset,
each model gets the predictions of the preceding models in the chain as
features (note that by default at training time each model gets the true
labels as features). These additional features allow each chain to exploit
correlations among the classes. The Jaccard similarity score for each chain
tends to be greater than that of the set of independent logistic models.

Because the models in each chain are arranged randomly there is significant
variation in performance among the chains. Presumably there is an optimal
ordering of the classes in a chain that will yield the best performance.
However we do not know that ordering a priori. Instead we can construct a
voting ensemble of classifier chains by averaging the binary predictions of
the chains and applying a threshold of 0.5. The Jaccard similarity score of the
ensemble is greater than that of the independent models and tends to exceed
the score of each chain in the ensemble (although this is not guaranteed
with randomly ordered chains).
"""

print(__doc__)

# Author: Adam Kleczewski
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt
from sklearn.multioutput import ClassifierChain
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import jaccard_similarity_score
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import fetch_mldata

# Load a multi-label dataset
yeast = fetch_mldata('yeast')
X = yeast['data']
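# The target matrix is stored as a sparse (n_labels, n_samples) matrix;
# transpose and densify it into a (n_samples, n_labels) array.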
Y = yeast['target'].transpose().toarray()
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2,
                                                    random_state=0)

# Fit an independent logistic regression model for each class using the
# OneVsRestClassifier wrapper.
ovr = OneVsRestClassifier(LogisticRegression())
ovr.fit(X_train, Y_train)
Y_pred_ovr = ovr.predict(X_test)
ovr_jaccard_score = jaccard_similarity_score(Y_test, Y_pred_ovr)

# Fit an ensemble of logistic regression classifier chains and take the
# average prediction of all the chains.
chains = [ClassifierChain(LogisticRegression(), order='random', random_state=i)
          for i in range(10)]
for chain in chains:
    chain.fit(X_train, Y_train)

Y_pred_chains = np.array([chain.predict(X_test) for chain in chains])
chain_jaccard_scores = [jaccard_similarity_score(Y_test, Y_pred_chain >= .5)
                        for Y_pred_chain in Y_pred_chains]

Y_pred_ensemble = Y_pred_chains.mean(axis=0)
ensemble_jaccard_score = jaccard_similarity_score(Y_test,
                                                  Y_pred_ensemble >= .5)

model_scores = [ovr_jaccard_score] + chain_jaccard_scores
model_scores.append(ensemble_jaccard_score)

model_names = ('Independent Models',
               'Chain 1',
               'Chain 2',
               'Chain 3',
               'Chain 4',
               'Chain 5',
               'Chain 6',
               'Chain 7',
               'Chain 8',
               'Chain 9',
               'Chain 10',
               'Ensemble Average')

y_pos = np.arange(len(model_names))
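# Shift the chain and ensemble bars right by one position, and the ensemble
# bar by one more, to leave visual gaps between the three groups.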
y_pos[1:] += 1
y_pos[-1] += 1

# Plot the Jaccard similarity scores for the independent model, each of the
# chains, and the ensemble (note that the vertical axis on this plot does
# not begin at 0).

fig = plt.figure(figsize=(7, 4))
plt.title('Classifier Chain Ensemble')
plt.xticks(y_pos, model_names, rotation='vertical')
plt.ylabel('Jaccard Similarity Score')
plt.ylim([min(model_scores) * .9, max(model_scores) * 1.1])
colors = ['r'] + ['b'] * len(chain_jaccard_scores) + ['g']
plt.bar(y_pos, model_scores, align='center', alpha=0.5, color=colors)
plt.show()