diff --git a/sklearn/metrics/_plot/confusion_matrix.py b/sklearn/metrics/_plot/confusion_matrix.py index c858ac3950f86..f3c380962f9f4 100644 --- a/sklearn/metrics/_plot/confusion_matrix.py +++ b/sklearn/metrics/_plot/confusion_matrix.py @@ -40,6 +40,24 @@ class ConfusionMatrixDisplay: figure_ : matplotlib Figure Figure containing the confusion matrix. + + Examples + -------- + >>> from sklearn.datasets import make_classification + >>> from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.svm import SVC + >>> X, y = make_classification(random_state=0) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, + ... random_state=0) + >>> clf = SVC(random_state=0) + >>> clf.fit(X_train, y_train) + SVC(random_state=0) + >>> predictions = clf.predict(X_test) + >>> cm = confusion_matrix(y_test, predictions, labels=clf.classes_) + >>> disp = ConfusionMatrixDisplay(confusion_matrix=cm, + ... display_labels=clf.classes_) + >>> disp.plot() # doctest: +SKIP """ def __init__(self, confusion_matrix, *, display_labels=None): self.confusion_matrix = confusion_matrix diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index bb2a91c198c41..7a04ce02f3ef6 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -40,6 +40,24 @@ class PrecisionRecallDisplay: figure_ : matplotlib Figure Figure containing the curve. + + Examples + -------- + >>> from sklearn.datasets import make_classification + >>> from sklearn.metrics import (precision_recall_curve, + ... PrecisionRecallDisplay) + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.svm import SVC + >>> X, y = make_classification(random_state=0) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, + ... random_state=0) + >>> clf = SVC(random_state=0) + >>> clf.fit(X_train, y_train) + SVC(random_state=0) + >>> predictions = clf.predict(X_test) + >>> precision, recall, _ = precision_recall_curve(y_test, predictions) + >>> disp = PrecisionRecallDisplay(precision=precision, recall=recall) + >>> disp.plot() # doctest: +SKIP """ def __init__(self, precision, recall, *, average_precision=None, estimator_name=None):