|
| 1 | +import numpy as np |
| 2 | + |
| 3 | +from ..base import BaseEstimator |
| 4 | +from ..utils.validation import check_X_y, check_array, check_is_fitted |
| 5 | +from ..utils import safe_mask |
| 6 | + |
| 7 | + |
def _check_estimator(estimator):
    """Validate that ``estimator`` exposes ``predict_proba``.

    Parameters
    ----------
    estimator : object
        The candidate base estimator.

    Raises
    ------
    ValueError
        If ``estimator`` has no ``predict_proba`` attribute.
    """
    if hasattr(estimator, "predict_proba"):
        return
    raise ValueError("The base estimator should implement predict_proba!")
| 12 | + |
| 13 | + |
class SelfTraining(BaseEstimator):
    """Self-training classifier.

    Trains a wrapped probabilistic estimator on a partially labeled
    dataset. Samples whose label is ``-1`` are treated as unlabeled. The
    wrapped estimator is fitted on the labeled samples and then used to
    label the unlabeled samples whose highest predicted probability is
    strictly greater than ``threshold``. Those samples are moved to the
    labeled set and the estimator is refitted; this repeats until no
    unlabeled samples remain, no prediction is confident enough, or
    ``max_iter`` rounds have been performed.

    Parameters
    ----------
    estimator : estimator object
        An estimator object implementing `fit`, `predict` and
        `predict_proba`. It is fitted in place by `fit`.

    threshold : float, default=0.7
        Threshold above which predictions are added to the labeled dataset.

    max_iter : integer, default=500
        Maximum number of self-labeling rounds allowed.

    """

    def __init__(self, estimator, threshold=0.7, max_iter=500):
        self.estimator = estimator
        self.threshold = threshold
        self.max_iter = max_iter

    def fit(self, X, y):
        """Fit the wrapped estimator using labeled and self-labeled data.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Array representing the data.
        y : array-like, shape = (n_samples,)
            Array representing the labels; ``-1`` marks an unlabeled
            sample.

        Returns
        -------
        self : returns an instance of self.

        Raises
        ------
        ValueError
            If the wrapped estimator does not implement `predict_proba`.
        """
        X, y = check_X_y(X, y)
        _check_estimator(self.estimator)

        # Split into the data usable for supervised training and the pool
        # of samples that still need a label.
        labeled = y != -1
        X_labeled, y_labeled = X[labeled], y[labeled]
        X_unlabeled = X[~labeled]

        self.estimator.fit(X_labeled, y_labeled)

        n_iter = 0
        while len(X_unlabeled) > 0 and n_iter < self.max_iter:
            n_iter += 1

            # Select predictions whose confidence is above the threshold.
            proba = self.estimator.predict_proba(X_unlabeled)
            confident = np.max(proba, axis=1) > self.threshold
            if not confident.any():
                # Nothing new would be added, so a further round would
                # only refit on the same training set.
                break
            pred = self.estimator.predict(X_unlabeled)

            # Move the confidently predicted samples to the labeled set.
            X_labeled = np.concatenate(
                (X_labeled, X_unlabeled[confident]), axis=0)
            y_labeled = np.concatenate(
                (y_labeled, pred[confident]), axis=0)
            X_unlabeled = X_unlabeled[~confident]

            # Refit so the next round (and the final model) sees the
            # newly labeled samples.
            self.estimator.fit(X_labeled, y_labeled)

        return self

    def predict(self, X):
        """Predict on a dataset.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Array representing the data.

        Returns
        -------
        y : array-like, shape = (n_samples,)
            Array with predicted labels, as returned by the wrapped
            estimator's `predict`.
        """
        # NOTE(review): `estimator` is assigned in __init__, so if
        # check_is_fitted only tests for attribute presence this cannot
        # detect an unfitted model -- confirm and consider a fitted
        # attribute set by `fit`.
        check_is_fitted(self, 'estimator')
        X = check_array(X)
        return self.estimator.predict(X)

    def predict_proba(self, X):
        """Predict probability for each possible outcome.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Array representing the data.

        Returns
        -------
        y : array-like, shape = (n_samples, n_classes)
            Array with prediction probabilities, as returned by the
            wrapped estimator's `predict_proba`.

        Raises
        ------
        ValueError
            If the wrapped estimator does not implement `predict_proba`.
        """
        _check_estimator(self.estimator)
        # NOTE(review): see `predict` -- this check may always pass.
        check_is_fitted(self, 'estimator')
        # Validate the input the same way `predict` does.
        X = check_array(X)
        return self.estimator.predict_proba(X)
0 commit comments