[MRG+1] Classifier chain (#7602) · jwjohnson314/scikit-learn@8dbc050 · GitHub


Commit 8dbc050

Adam Kleczewski authored and Jeremiah Johnson committed
[MRG+1] Classifier chain (scikit-learn#7602)
[MRG+2] Classifier chain
1 parent ed5bebc commit 8dbc050

File tree

8 files changed: +549 −11 lines changed

doc/modules/multiclass.rst

Lines changed: 27 additions & 0 deletions
@@ -348,3 +348,30 @@ Below is an example of multioutput classification:

       [0, 0, 2],
       [2, 0, 0]])

Classifier Chain
================

Classifier chains (see :class:`ClassifierChain`) are a way of combining a
number of binary classifiers into a single multi-label model that is capable
of exploiting correlations among targets.

For a multi-label classification problem with N classes, N binary
classifiers are assigned an integer between 0 and N-1. These integers
define the order of models in the chain. Each classifier is then fit on the
available training data plus the true labels of the classes whose
models were assigned a lower number.

When predicting, the true labels will not be available. Instead, the
predictions of each model are passed on to the subsequent models in the
chain to be used as features.

Clearly, the order of the chain is important. The first model in the chain
has no information about the other labels, while the last model in the chain
has features indicating the presence of all of the other labels. In general
one does not know the optimal ordering of the models in the chain, so
typically many randomly ordered chains are fit and their predictions are
averaged together.

.. topic:: References:

    Jesse Read, Bernhard Pfahringer, Geoff Holmes, Eibe Frank,
    "Classifier Chains for Multi-label Classification", 2009.

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,9 @@ Changelog

New features
............

- Added :class:`multioutput.ClassifierChain` for multi-label
  classification. By `Adam Kleczewski <adamklec>`_.

- Validation that input data contains no NaN or inf can now be suppressed
  using :func:`config_context`, at your own risk. This will save on runtime,
  and may be particularly useful for prediction time. :issue:`7548` by
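
The :func:`config_context` feature mentioned in that second entry is separate from this commit; a minimal sketch of the usage it refers to (data invented here):

import numpy as np
import sklearn
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
X = rng.rand(50, 4)
y = (X[:, 0] > 0.5).astype(int)
clf = LogisticRegression().fit(X, y)

# Inside this block scikit-learn skips its NaN/inf input validation,
# saving runtime; non-finite values would then pass through undetected.
with sklearn.config_context(assume_finite=True):
    clf.predict(X)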

examples/multioutput/README.txt

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@

.. _multioutput_examples:

Multioutput methods
-------------------

Examples concerning the :mod:`sklearn.multioutput` module.
examples/multioutput/plot_classifier_chain_yeast.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@

"""
============================
Classifier Chain
============================
Example of using classifier chain on a multilabel dataset.

For this example we will use the `yeast
<http://mldata.org/repository/data/viewslug/yeast/>`_ dataset which
contains 2417 datapoints, each with 103 features and 14 possible labels. Each
datapoint has at least one label. As a baseline we first train a logistic
regression classifier for each of the 14 labels. To evaluate the performance
of these classifiers we predict on a held-out test set and calculate the
:ref:`Jaccard similarity score <jaccard_similarity_score>`.

Next we create 10 classifier chains. Each classifier chain contains a
logistic regression model for each of the 14 labels. The models in each
chain are ordered randomly. In addition to the 103 features in the dataset,
each model gets the predictions of the preceding models in the chain as
features (note that by default at training time each model gets the true
labels as features). These additional features allow each chain to exploit
correlations among the classes. The Jaccard similarity score for each chain
tends to be greater than that of the set of independent logistic models.

Because the models in each chain are arranged randomly there is significant
variation in performance among the chains. Presumably there is an optimal
ordering of the classes in a chain that will yield the best performance.
However we do not know that ordering a priori. Instead we can construct a
voting ensemble of classifier chains by averaging the binary predictions of
the chains and applying a threshold of 0.5. The Jaccard similarity score of
the ensemble is greater than that of the independent models and tends to
exceed the score of each chain in the ensemble (although this is not
guaranteed with randomly ordered chains).
"""
print(__doc__)

# Author: Adam Kleczewski
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt
from sklearn.multioutput import ClassifierChain
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import jaccard_similarity_score
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import fetch_mldata

# Load a multi-label dataset
yeast = fetch_mldata('yeast')
X = yeast['data']
Y = yeast['target'].transpose().toarray()
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2,
                                                    random_state=0)

# Fit an independent logistic regression model for each class using the
# OneVsRestClassifier wrapper.
ovr = OneVsRestClassifier(LogisticRegression())
ovr.fit(X_train, Y_train)
Y_pred_ovr = ovr.predict(X_test)
ovr_jaccard_score = jaccard_similarity_score(Y_test, Y_pred_ovr)

# Fit an ensemble of logistic regression classifier chains and take the
# average prediction of all the chains.
chains = [ClassifierChain(LogisticRegression(), order='random', random_state=i)
          for i in range(10)]
for chain in chains:
    chain.fit(X_train, Y_train)

Y_pred_chains = np.array([chain.predict(X_test) for chain in chains])
chain_jaccard_scores = [jaccard_similarity_score(Y_test, Y_pred_chain >= .5)
                        for Y_pred_chain in Y_pred_chains]

Y_pred_ensemble = Y_pred_chains.mean(axis=0)
ensemble_jaccard_score = jaccard_similarity_score(Y_test,
                                                  Y_pred_ensemble >= .5)

model_scores = [ovr_jaccard_score] + chain_jaccard_scores
model_scores.append(ensemble_jaccard_score)

model_names = ('Independent Models',
               'Chain 1',
               'Chain 2',
               'Chain 3',
               'Chain 4',
               'Chain 5',
               'Chain 6',
               'Chain 7',
               'Chain 8',
               'Chain 9',
               'Chain 10',
               'Ensemble Average')

# Leave a gap on the x-axis between the independent models, the chains,
# and the ensemble bar.
y_pos = np.arange(len(model_names))
y_pos[1:] += 1
y_pos[-1] += 1

# Plot the Jaccard similarity scores for the independent model, each of the
# chains, and the ensemble (note that the vertical axis on this plot does
# not begin at 0).
fig = plt.figure(figsize=(7, 4))
plt.title('Classifier Chain Ensemble')
plt.xticks(y_pos, model_names, rotation='vertical')
plt.ylabel('Jaccard Similarity Score')
plt.ylim([min(model_scores) * .9, max(model_scores) * 1.1])
colors = ['r'] + ['b'] * len(chain_jaccard_scores) + ['g']
plt.bar(y_pos, model_scores, align='center', alpha=0.5, color=colors)
plt.show()
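
For intuition about the averaging-and-thresholding step in the script above, here is a tiny worked sketch (arrays invented for illustration):

import numpy as np
from sklearn.metrics import jaccard_similarity_score

# Probabilistic predictions from two chains for 2 samples and 3 labels.
p1 = np.array([[0.9, 0.2, 0.6],
               [0.1, 0.8, 0.4]])
p2 = np.array([[0.7, 0.4, 0.3],
               [0.3, 0.6, 0.5]])
Y_true = np.array([[1, 0, 1],
                   [0, 1, 0]])

# Average the chains' predictions, then threshold at 0.5.
ensemble = np.mean([p1, p2], axis=0)  # first entry: (0.9 + 0.7) / 2 = 0.8
Y_pred = ensemble >= 0.5

# Sample-wise Jaccard scores are 1/2 and 1/1, averaging to 0.75.
print(jaccard_similarity_score(Y_true, Y_pred))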

0 commit comments