8000 [MRG+1] Classifier chain (#7602) · dmohns/scikit-learn@76cb8aa · GitHub
[go: up one dir, main page]

8000
Skip to content
.hEHvLI{min-width:0;-webkit-align-items:center;-webkit-box-align:center;-ms-flex-align:center;align-items:center;}/*!sc*/ .bmcJak{min-width:0;}/*!sc*/ .fyKNMY[data-size="medium"]{color:var(--fgColor-default,var(--color-fg-default,#1F2328));}/*!sc*/ .gUkoLg{-webkit-box-pack:center;-webkit-justify-content:center;-ms-flex-pack:center;justify-content:center;}/*!sc*/ .PhXDz{font-weight:600;color:var(--fgColor-default,var(--color-fg-default,#1F2328));}/*!sc*/ .gLSgdJ{font-weight:600;color:var(--fgColor-default,var(--color-fg-default,#1F2328));}/*!sc*/ .gLSgdJ:hover{color:var(--fgColor-default,var(--color-fg-default,#1F2328));}/*!sc*/ .irPhWZ{width:60px;}/*!sc*/ .dNbsEP{width:62px;}/*!sc*/ .kHfwUD{width:60px;height:22px;}/*!sc*/ .bHLmSv{position:absolute;inset:0 -2px;cursor:col-resize;background-color:transparent;-webkit-transition-delay:0.1s;transition-delay:0.1s;}/*!sc*/ .bHLmSv:hover{background-color:var(--bgColor-neutral-muted,var(--color-neutral-muted,rgba(175,184,193,0.2)));}/*!sc*/ data-styled.g1[id="Box-sc-g0xbh4-0"]{content:"hEHvLI,bmcJak,fyKNMY,gUkoLg,PhXDz,gLSgdJ,irPhWZ,dNbsEP,kHfwUD,bHLmSv,"}/*!sc*/ .jjwhNb{position:relative;display:inline-block;display:-webkit-box;display:-webkit-flex;display:-ms-flexbox;display:flex;}/*!sc*/ .jjwhNb::after{position:absolute;z-index:1000000;display:none;padding:0.5em 0.75em;font:normal normal 11px/1.5 -apple-system,BlinkMacSystemFont,"Segoe UI","Noto Sans",Helvetica,Arial,sans-serif,"Apple Color Emoji","Segoe UI Emoji";-webkit-font-smoothing:subpixel-antialiased;color:var(--tooltip-fgColor,var(--fgColor-onEmphasis,var(--color-fg-on-emphasis,#ffffff)));text-align:center;-webkit-text-decoration:none;text-decoration:none;text-shadow:none;text-transform:none;-webkit-letter-spacing:normal;-moz-letter-spacing:normal;-ms-letter-spacing:normal;letter-spacing:normal;word-wrap:break-word;white-space:pre;pointer-events:none;content:attr(aria-label);background:var(--tooltip-bgColor,var(--bgColor-emphasis,var(--color-neutral-emphasis-plus,#24292f)));border-radius:6px;opacity:0;}/*!sc*/ @-webkit-keyframes tooltip-appear{from{opacity:0;}to{opacity:1;}}/*!sc*/ @keyframes tooltip-appear{from{opacity:0;}to{opacity:1;}}/*!sc*/ .jjwhNb:hover::after,.jjwhNb:active::after,.jjwhNb:focus::after,.jjwhNb:focus-within::after{display:inline-block;-webkit-text-decoration:none;text-decoration:none;-webkit-animation-name:tooltip-appear;animation-name:tooltip-appear;-webkit-animation-duration:0.1s;animation-duration:0.1s;-webkit-animation-fill-mode:forwards;animation-fill-mode:forwards;-webkit-animation-timing-function:ease-in;animation-timing-function:ease-in;-webkit-animation-delay:0s;animation-delay:0s;}/*!sc*/ .jjwhNb.tooltipped-no-delay:hover::after,.jjwhNb.tooltipped-no-delay:active::after,.jjwhNb.tooltipped-no-delay:focus::after,.jjwhNb.tooltipped-no-delay:focus-within::after{-webkit-animation-delay:0s;animation-delay:0s;}/*!sc*/ .jjwhNb.tooltipped-multiline:hover::after,.jjwhNb.tooltipped-multiline:active::after,.jjwhNb.tooltipped-multiline:focus::after,.jjwhNb.tooltipped-multiline:focus-within::after{display:table-cell;}/*!sc*/ .jjwhNb.tooltipped-s::after,.jjwhNb.tooltipped-se::after,.jjwhNb.tooltipped-sw::after{top:100%;right:50%;margin-top:6px;}/*!sc*/ .jjwhNb.tooltipped-se::after{right:auto;left:50%;margin-left:-16px;}/*!sc*/ .jjwhNb.tooltipped-sw::after{margin-right:-16px;}/*!sc*/ .jjwhNb.tooltipped-n::after,.jjwhNb.tooltipped-ne::after,.jjwhNb.tooltipped-nw::after{right:50%;bottom:100%;margin-bottom:6px;}/*!sc*/ .jjwhNb.tooltipped-ne::after{right:auto;left:50%;margin-left:-16px;}/*!sc*/ .jjwhNb.tooltipped-nw::after{margin-right:-16px;}/*!sc*/ .jjwhNb.tooltipped-s::after,.jjwhNb.tooltipped-n::after{-webkit-transform:translateX(50%);-ms-transform:translateX(50%);transform:translateX(50%);}/*!sc*/ .jjwhNb.tooltipped-w::after{right:100%;bottom:50%;margin-right:6px;-webkit-transform:translateY(50%);-ms-transform:translateY(50%);transform:translateY(50%);}/*!sc*/ .jjwhNb.tooltipped-e::after{bottom:50%;left:100%;margin-left:6px;-webkit-transform:translateY(50%);-ms-transform:translateY(50%);transform:translateY(50%);}/*!sc*/ .jjwhNb.tooltipped-multiline::after{width:-webkit-max-content;width:-moz-max-content;width:max-content;max-width:250px;word-wrap:break-word;white-space:pre-line;border-collapse:separate;}/*!sc*/ .jjwhNb.tooltipped-multiline.tooltipped-s::after,.jjwhNb.tooltipped-multiline.tooltipped-n::after{right:auto;left:50%;-webkit-transform:translateX(-50%);-ms-transform:translateX(-50%);transform:translateX(-50%);}/*!sc*/ .jjwhNb.tooltipped-multiline.tooltipped-w::after,.jjwhNb.tooltipped-multiline.tooltipped-e::after{right:100%;}/*!sc*/ .jjwhNb.tooltipped-align-right-2::after{right:0;margin-right:0;}/*!sc*/ .jjwhNb.tooltipped-align-left-2::after{left:0;margin-left:0;}/*!sc*/ data-styled.g4[id="Tooltip__TooltipBase-sc-17tf59c-0"]{content:"jjwhNb,"}/*!sc*/ .irithh{position:relative;overflow:hidden;-webkit-mask-image:radial-gradient(white,black);mask-image:radial-gradient(white,black);background-color:var(--bgColor-neutral-muted,var(--color-neutral-subtle,rgba(234,238,242,0.5)));border-radius:3px;display:block;height:1.2em;width:60px;}/*!sc*/ .irithh::after{-webkit-animation:crVFvv 1.5s infinite linear;animation:crVFvv 1.5s infinite linear;background:linear-gradient(90deg,transparent,var(--bgColor-neutral-muted,var(--color-neutral-subtle,rgba(234,238,242,0.5))),transparent);content:'';position:absolute;-webkit-transform:translateX(-100%);-ms-transform:translateX(-100%);transform:translateX(-100%);bottom:0;left:0;right:0;top:0;}/*!sc*/ .ihfxfT{position:relative;overflow:hidden;-webkit-mask-image:radial-gradient(white,black);mask-image:radial-gradient(white,black);background-color:var(--bgColor-neutral-muted,var(--color-neutral-subtle,rgba(234,238,242,0.5)));border-radius:3px;display:block;height:1.2em;width:62px;}/*!sc*/ .ihfxfT::after{-webkit-animation:crVFvv 1.5s infinite linear;animation:crVFvv 1.5s infinite linear;background:linear-gradient(90deg,transparent,var(--bgColor-neutral-muted,var(--color-neutral-subtle,rgba(234,238,242,0.5))),transparent);content:'';position:absolute;-webkit-transform:translateX(-100%);-ms-transform:translateX(-100%);transform:translateX(-100%);bottom:0;left:0;right:0;top:0;}/*!sc*/ .kRBfod{position:relative;overflow:hidden;-webkit-mask-image:radial-gradient(white,black);mask-image:radial-gradient(white,black);background-color:var(--bgColor-neutral-muted,var(--color-neutral-subtle,rgba(234,238,242,0.5)));border-radius:3px;display:block;height:1.2em;width:60px;height:22px;}/*!sc*/ .kRBfod::after{-webkit-animation:crVFvv 1.5s infinite linear;animation:crVFvv 1.5s infinite linear;background:linear-gradient(90deg,transparent,var(--bgColor-neutral-muted,var(--color-neutral-subtle,rgba(234,238,242,0.5))),transparent);content:'';position:absolute;-webkit-transform:translateX(-100%);-ms-transform:translateX(-100%);transform:translateX(-100%);bottom:0;left:0;right:0;top:0;}/*!sc*/ data-styled.g22[id="LoadingSkeleton-sc-695d630a-0"]{content:"irithh,ihfxfT,kRBfod,"}/*!sc*/ @-webkit-keyframes crVFvv{0%{-webkit-transform:translateX(-100%);-ms-transform:translateX(-100%);transform:translateX(-100%);}50%{-webkit-transform:translateX(100%);-ms-transform:translateX(100%);transform:translateX(100%);}100%{-webkit-transform:translateX(100%);-ms-transform:translateX(100%);transform:translateX(100%);}}/*!sc*/ @keyframes crVFvv{0%{-webkit-transform:translateX(-100%);-ms-transform:translateX(-100%);transform:translateX(-100%);}50%{-webkit-transform:translateX(100%);-ms-transform:translateX(100%);transform:translateX(100%);}100%{-webkit-transform:translateX(100%);-ms-transform:translateX(100%);transform:translateX(100%);}}/*!sc*/ data-styled.g43[id="sc-keyframes-crVFvv"]{content:"crVFvv,"}/*!sc*/

Commit 76cb8aa

Browse files
Adam Kleczewskidmohns
authored andcommitted
[MRG+1] Classifier chain (scikit-learn#7602)
[MRG+2] Classifier chain
1 parent 6d2a026 commit 76cb8aa

File tree

8 files changed

+549
-11
lines changed

8 files changed

+549
-11
lines changed

doc/modules/multiclass.rst

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,30 @@ Below is an example of multioutput classification:
348348
[0, 0, 2],
349349
[2, 0, 0]])
350350

351+
Classifier Chain
352+
================
353+
354+
Classifier chains (see :class:`ClassifierChain`) are a way of combining a
355+
number of binary classifiers into a single multi-label model that is capable
356+
of exploiting correlations among targets.
357+
358+
For a multi-label classification problem with N classes, N binary
359+
classifiers are assigned an integer between 0 and N-1. These integers
360+
define the order of models in the chain. Each classifier is then fit on the
361+
available training data plus the true labels of the classes whose
362+
models were assigned a lower number.
363+
364+
When predicting, the true labels will not be available. Instead the
365+
predictions of each model are passed on to the subsequent models in the
366+
chain to be used as features.
367+
368+
Clearly the order of the chain is important. The first model in the chain
369+
has no information about the other labels while the last model in the chain
370+
has features indicating the presence of all of the other labels. In general
371+
one does not know the optimal ordering of the models in the chain so
372+
typically many randomly ordered chains are fit and their predictions are
373+
averaged together.
374+
375+
.. topic:: References:
376+
Jesse Read, Bernhard Pfahringer, Geoff Holmes, Eibe Frank,
377+
"Classifier Chains for Multi-label Classification", 2009.

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ Changelog
3131
New features
3232
............
3333

34+
- Added :class:`multioutput.ClassifierChain` for multi-label
35+
classification. By `Adam Kleczewski <adamklec>`_.
36+
3437
- Validation that input data contains no NaN or inf can now be suppressed
3538
using :func:`config_context`, at your own risk. This will save on runtime,
3639
and may be particularly useful for prediction time. :issue:`7548` by

examples/multioutput/README.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
.. _multioutput_examples:
2+
3+
Multioutput methods
4+
----------------
5+
6+
Examples concerning the :mod:`sklearn.multioutput` module.
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""
2+
============================
3+
Classifier Chain
4+
============================
5+
Example of using classifier chain on a multilabel dataset.
6+
7+
For this example we will use the `yeast
8+
http://mldata.org/repository/data/viewslug/yeast/`_ dataset which
9+
contains 2417 datapoints each with 103 features and 14 possible labels. Each
10+
datapoint has at least one label. As a baseline we first train a logistic
11+
regression classifier for each of the 14 labels. To evaluate the performance
12+
of these classifiers we predict on a held-out test set and calculate the
13+
:ref:`User Guide <jaccard_similarity_score>`.
14+
15+
Next we create 10 classifier chains. Each classifier chain contains a
16+
logistic regression model for each of the 14 labels. The models in each
17+
chain are ordered randomly. In addition to the 103 features in the dataset,
18+
each model gets the predictions of the preceding models in the chain as
19+
features (note that by default at training time each model gets the true
20+
labels as features). These additional features allow each chain to exploit
21+
correlations among the classes. The Jaccard similarity score for each chain
22+
tends to be greater than that of the set independent logistic models.
23+
24+
Because the models in each chain are arranged randomly there is significant
25+
variation in performance among the chains. Presumably there is an optimal
26+
ordering of the classes in a chain that will yield the best performance.
27+
However we do not know that ordering a priori. Instead we can construct an
28+
voting ensemble of classifier chains by averaging the binary predictions of
29+
the chains and apply a threshold of 0.5. The Jaccard similarity score of the
30+
ensemble is greater than that of the independent models and tends to exceed
31+
the score of each chain in the ensemble (although this is not guaranteed
32+
with randomly ordered chains).
33+
"""
34+
35+
print(__doc__)
36+
37+
# Author: Adam Kleczewski
38+
# License: BSD 3 clause
39+
40+
import numpy as np
41+
import matplotlib.pyplot as plt
42+
from sklearn.multioutput import ClassifierChain
43+
from sklearn.model_selection import train_test_split
44+
from sklearn.multiclass import OneVsRestClassifier
45+
from sklearn.metrics import jaccard_similarity_score
46+
from sklearn.linear_model import LogisticRegression
47+
from sklearn.datasets import fetch_mldata
48+
49+
# Load a multi-label dataset
50+
yeast = fetch_mldata('yeast')
51+
X = yeast['data']
52+
Y = yeast['target'].transpose().toarray()
53+
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2,
54+
random_state=0)
55+
56+
# Fit an independent logistic regression model for each class using the
57+
# OneVsRestClassifier wrapper.
58+
ovr = OneVsRestClassifier(LogisticRegression())
59+
ovr.fit(X_train, Y_train)
60+
Y_pred_ovr = ovr.predict(X_test)
61+
ovr_jaccard_score = jaccard_similarity_score(Y_test, Y_pred_ovr)
62+
63+
# Fit an ensemble of logistic regression classifier chains and take the
64+
# take the average prediction of all the chains.
65+
chains = [ClassifierChain(LogisticRegression(), order='random', random_state=i)
66+
for i in range(10)]
67+
for chain in chains:
68+
chain.fit(X_train, Y_train)
69+
70+
Y_pred_chains = np.array([chain.predict(X_test) for chain in
71+
chains])
72+
chain_jaccard_scores = [jaccard_similarity_score(Y_test, Y_pred_chain >= .5)
73+
for Y_pred_chain in Y_pred_chains]
74+
75+
Y_pred_ensemble = Y_pred_chains.mean(axis=0)
76+
ensemble_jaccard_score = jaccard_similarity_score(Y_test,
77+
Y_pred_ensemble >= .5)
78+
79+
model_scores = [ovr_jaccard_score] + chain_jaccard_scores
80+
model_scores.append(ensemble_jaccard_score)
81+
82+
model_names = ('Independent Models',
83+
'Chain 1',
84+
'Chain 2',
85+
'Chain 3',
86+
'Chain 4',
87+
'Chain 5',
88+
'Chain 6',
89+
'Chain 7',
90+
'Chain 8',
91+
'Chain 9',
92+
'Chain 10',
93+
'Ensemble Average')
94+
95+
y_pos = np.arange(len(model_names))
96+
y_pos[1:] += 1
97+
y_pos[-1] += 1
98+
99+
# Plot the Jaccard similarity scores for the independent model, each of the
100+
# chains, and the ensemble (note that the vertical axis on this plot does
101+
# not begin at 0).
102+
103+
fig = plt.figure(figsize=(7, 4))
104+
plt.title('Classifier Chain Ensemble')
105+
plt.xticks(y_pos, model_names, rotation='vertical')
106+
plt.ylabel('Jaccard Similarity Score')
107+
plt.ylim([min(model_scores) * .9, max(model_scores) * 1.1])
108+
colors = ['r'] + ['b'] * len(chain_jaccard_scores) + ['g']
109+
plt.bar(y_pos, model_scores, align='center', alpha=0.5, color=colors)
110+
plt.show()

0 commit comments

Comments
 (0)
0