|
5 | 5 | Example of using classifier chain on a multilabel dataset. |
6 | 6 |
|
7 | 7 | For this example we will use the `yeast |
8 | | -<http://mldata.org/repository/data/viewslug/yeast>`_ dataset which |
9 | | -contains 2417 datapoints each with 103 features and 14 possible labels. Each |
10 | | -datapoint has at least one label. As a baseline we first train a logistic |
11 | | -regression classifier for each of the 14 labels. To evaluate the performance |
12 | | -of these classifiers we predict on a held-out test set and calculate the |
13 | | -:ref:`User Guide <jaccard_similarity_score>`. |
| 8 | +<http://mldata.org/repository/data/viewslug/yeast>`_ dataset which contains |
| 9 | +2417 data points each with 103 features and 14 possible labels. Each |
| 10 | +data point has at least one label. As a baseline we first train a logistic |
| 11 | +regression classifier for each of the 14 labels. To evaluate the performance of |
| 12 | +these classifiers we predict on a held-out test set and calculate the |
| 13 | +:ref:`jaccard similarity score <jaccard_similarity_score>`. |
14 | 14 |
|
15 | 15 | Next we create 10 classifier chains. Each classifier chain contains a |
16 | 16 | logistic regression model for each of the 14 labels. The models in each |
|
79 | 79 | model_scores = [ovr_jaccard_score] + chain_jaccard_scores |
80 | 80 | model_scores.append(ensemble_jaccard_score) |
81 | 81 |
|
82 | | -model_names = ('Independent Models', |
| 82 | +model_names = ('Independent', |
83 | 83 | 'Chain 1', |
84 | 84 | 'Chain 2', |
85 | 85 | 'Chain 3', |
|
90 | 90 | 'Chain 8', |
91 | 91 | 'Chain 9', |
92 | 92 | 'Chain 10', |
93 | | - 'Ensemble Average') |
| 93 | + 'Ensemble') |
94 | 94 |
|
95 | | -y_pos = np.arange(len(model_names)) |
96 | | -y_pos[1:] += 1 |
97 | | -y_pos[-1] += 1 |
| 95 | +x_pos = np.arange(len(model_names)) |
98 | 96 |
|
99 | 97 | # Plot the Jaccard similarity scores for the independent model, each of the |
100 | 98 | # chains, and the ensemble (note that the vertical axis on this plot does |
101 | 99 | # not begin at 0). |
102 | 100 |
|
103 | | -fig = plt.figure(figsize=(7, 4)) |
104 | | -plt.title('Classifier Chain Ensemble') |
105 | | -plt.xticks(y_pos, model_names, rotation='vertical') |
106 | | -plt.ylabel('Jaccard Similarity Score') |
107 | | -plt.ylim([min(model_scores) * .9, max(model_scores) * 1.1]) |
| 101 | +fig, ax = plt.subplots(figsize=(7, 4)) |
| 102 | +ax.grid(True) |
| 103 | +ax.set_title('Classifier Chain Ensemble Performance Comparison') |
| 104 | +ax.set_xticks(x_pos) |
| 105 | +ax.set_xticklabels(model_names, rotation='vertical') |
| 106 | +ax.set_ylabel('Jaccard Similarity Score') |
| 107 | +ax.set_ylim([min(model_scores) * .9, max(model_scores) * 1.1]) |
108 | 108 | colors = ['r'] + ['b'] * len(chain_jaccard_scores) + ['g'] |
109 | | -plt.bar(y_pos, model_scores, align='center', alpha=0.5, color=colors) |
| 109 | +ax.bar(x_pos, model_scores, alpha=0.5, color=colors) |
| 110 | +plt.tight_layout() |
110 | 111 | plt.show() |
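
Note on the bar positions: replacing `y_pos` with `x_pos` is more than a rename. The removed lines shifted the positions to leave a gap after the first bar and before the last one, while the new line uses evenly spaced positions. Below is a minimal sketch of the difference using plain Python lists instead of NumPy. The helper names `old_positions` and `new_positions` are hypothetical, and the count of 12 models assumes one independent baseline, the 10 chains named in the docstring, and one ensemble.

```python
def old_positions(n):
    """Mimic the removed y_pos logic with a plain list (illustrative helper)."""
    pos = list(range(n))
    # y_pos[1:] += 1 -> shift every bar after the first one step to the right
    pos[1:] = [p + 1 for p in pos[1:]]
    # y_pos[-1] += 1 -> shift the last bar one more step to the right
    pos[-1] += 1
    return pos


def new_positions(n):
    """Mimic the added x_pos logic: evenly spaced positions."""
    return list(range(n))


# Assumed count: 1 independent baseline + 10 chains + 1 ensemble
n_models = 12
print(old_positions(n_models))
print(new_positions(n_models))
```

With the old positions, the baseline and ensemble bars sat apart from the chain bars; with the new positions, the three groups are told apart only by the `colors` list.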
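
Note on the metric: the docstring in this diff evaluates each model with the Jaccard similarity score on a held-out test set. As a rough illustration of what that number means for multilabel data, here is a minimal pure-Python sketch, assuming the per-sample averaged definition (size of the intersection of true and predicted label sets divided by the size of their union, averaged over samples). The function name `jaccard_sample_average` and the toy label matrices are hypothetical, and treating two empty label sets as a perfect match is an assumption of this sketch, not something the diff states.

```python
def jaccard_sample_average(y_true, y_pred):
    """Average per-sample Jaccard similarity for binary indicator rows."""
    scores = []
    for true_row, pred_row in zip(y_true, y_pred):
        # Convert each indicator row into the set of active label indices
        true_set = {i for i, v in enumerate(true_row) if v}
        pred_set = {i for i, v in enumerate(pred_row) if v}
        union = true_set | pred_set
        if not union:
            # Assumption: two empty label sets count as a perfect match
            scores.append(1.0)
        else:
            scores.append(len(true_set & pred_set) / len(union))
    return sum(scores) / len(scores)


# Toy data (made up for illustration): 3 samples, 4 possible labels
y_true = [[1, 0, 1, 0], [0, 1, 0, 0], [1, 1, 0, 1]]
y_pred = [[1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]]

print(jaccard_sample_average(y_true, y_pred))
```

The per-sample scores here are 1/2, 1, and 2/3, so a model is rewarded for partially correct label sets rather than only for exact matches.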