"""
============================
Classifier Chain
============================
Example of using classifier chain on a multilabel dataset.

For this example we will use the `yeast
<http://mldata.org/repository/data/viewslug/yeast/>`_ dataset, which
contains 2417 datapoints, each with 103 features and 14 possible labels. Each
datapoint has at least one label. As a baseline we first train a logistic
regression classifier for each of the 14 labels. To evaluate the performance
of these classifiers we predict on a held-out test set and calculate the
:ref:`Jaccard similarity score <jaccard_similarity_score>` for each sample.
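
For multilabel predictions, ``jaccard_similarity_score`` computes, for each
sample, the size of the intersection of the true and predicted label sets
divided by the size of their union, and averages this over all samples:

.. math::

    J(y, \hat{y}) = \frac{|y \cap \hat{y}|}{|y \cup \hat{y}|}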
Next we create 10 classifier chains. Each classifier chain contains a
logistic regression model for each of the 14 labels. The models in each
chain are ordered randomly. In addition to the 103 features in the dataset,
each model gets the predictions of the preceding models in the chain as
features (note that by default at training time each model gets the true
labels as features). These additional features allow each chain to exploit
correlations among the classes. The Jaccard similarity score for each chain
tends to be greater than that of the set of independent logistic models.
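Conceptually, prediction with a fitted chain feeds each predicted label back
in as an extra feature for the next model, along the lines of this simplified
sketch (where ``chain_models`` stands for the chain's ordered binary
classifiers; this is not the actual ``ClassifierChain`` implementation)::

    X_aug = X_test
    for clf in chain_models:                      # one binary model per label
        y_i = clf.predict(X_aug)                  # predict the next label
        X_aug = np.hstack([X_aug, y_i[:, None]])  # append it as a feature
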
Because the models in each chain are arranged randomly there is significant
variation in performance among the chains. Presumably there is an optimal
ordering of the classes in a chain that will yield the best performance.
However we do not know that ordering a priori. Instead we can construct a
voting ensemble of classifier chains by averaging the binary predictions of
the chains and applying a threshold of 0.5. The Jaccard similarity score of
the ensemble is greater than that of the independent models and tends to
exceed the score of each chain in the ensemble (although this is not
guaranteed with randomly ordered chains).
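
In symbols, with :math:`C` chains producing binary predictions
:math:`\hat{y}_k^{(c)}` for label :math:`k`, the ensemble predicts

.. math::

    \hat{y}_k = \mathbf{1}\left[\frac{1}{C}\sum_{c=1}^{C}
    \hat{y}_k^{(c)} \ge 0.5\right]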
"""

print(__doc__)

# Author: Adam Kleczewski
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import fetch_mldata
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import jaccard_similarity_score
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain

# Load a multi-label dataset
yeast = fetch_mldata('yeast')
X = yeast['data']
Y = yeast['target'].transpose().toarray()
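# X has shape (2417, 103); Y is a (2417, 14) binary label indicator matrix.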
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2,
                                                    random_state=0)

# Fit an independent logistic regression model for each class using the
# OneVsRestClassifier wrapper.
ovr = OneVsRestClassifier(LogisticRegression())
ovr.fit(X_train, Y_train)
Y_pred_ovr = ovr.predict(X_test)
ovr_jaccard_score = jaccard_similarity_score(Y_test, Y_pred_ovr)

# Fit an ensemble of logistic regression classifier chains and take the
# average prediction of all the chains.
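# order='random' gives each chain a different random label ordering; using a
# distinct random_state per chain makes the orderings differ from each other.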
chains = [ClassifierChain(LogisticRegression(), order='random', random_state=i)
          for i in range(10)]
for chain in chains:
    chain.fit(X_train, Y_train)

Y_pred_chains = np.array([chain.predict(X_test) for chain in chains])
chain_jaccard_scores = [jaccard_similarity_score(Y_test, Y_pred_chain >= .5)
                        for Y_pred_chain in Y_pred_chains]

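# Average the binary predictions of the chains and threshold at 0.5 to get a
# majority-vote ensemble prediction.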
Y_pred_ensemble = Y_pred_chains.mean(axis=0)
ensemble_jaccard_score = jaccard_similarity_score(Y_test,
                                                  Y_pred_ensemble >= .5)

model_scores = ([ovr_jaccard_score] + chain_jaccard_scores +
                [ensemble_jaccard_score])

model_names = ('Independent Models',
               'Chain 1',
               'Chain 2',
               'Chain 3',
               'Chain 4',
               'Chain 5',
               'Chain 6',
               'Chain 7',
               'Chain 8',
               'Chain 9',
               'Chain 10',
               'Ensemble Average')

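# Offset the bar positions so there is a visible gap after the independent
# model and before the ensemble average in the bar chart.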
y_pos = np.arange(len(model_names))
y_pos[1:] += 1
y_pos[-1] += 1

# Plot the Jaccard similarity scores for the independent model, each of the
# chains, and the ensemble (note that the vertical axis on this plot does
# not begin at 0).

plt.figure(figsize=(7, 4))
plt.title('Classifier Chain Ensemble')
plt.xticks(y_pos, model_names, rotation='vertical')
plt.ylabel('Jaccard Similarity Score')
plt.ylim([min(model_scores) * .9, max(model_scores) * 1.1])
colors = ['r'] + ['b'] * len(chain_jaccard_scores) + ['g']
plt.bar(y_pos, model_scores, align='center', alpha=0.5, color=colors)
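# tight_layout keeps the vertical tick labels from being clipped.
plt.tight_layout()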
plt.show()