diff --git a/sklearn/inspection/tests/test_partial_dependence.py b/sklearn/inspection/tests/test_partial_dependence.py index 406a6b1bf2e42..d2d3c7818e448 100644 --- a/sklearn/inspection/tests/test_partial_dependence.py +++ b/sklearn/inspection/tests/test_partial_dependence.py @@ -28,7 +28,6 @@ from sklearn.utils.testing import assert_allclose from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import if_matplotlib -from sklearn.utils.testing import close_figure # toy sample @@ -437,7 +436,7 @@ def test_plot_partial_dependence(): assert len(axs) == 3 assert all(ax.has_data for ax in axs) - close_figure() + plt.close('all') @if_matplotlib @@ -471,7 +470,7 @@ def test_plot_partial_dependence_multiclass(): assert len(axs) == 2 assert all(ax.has_data for ax in axs) - close_figure() + plt.close('all') @if_matplotlib @@ -499,7 +498,7 @@ def test_plot_partial_dependence_multioutput(): assert len(axs) == 2 assert all(ax.has_data for ax in axs) - close_figure() + plt.close('all') @if_matplotlib @@ -533,13 +532,14 @@ def test_plot_partial_dependence_multioutput(): @pytest.mark.filterwarnings('ignore:Default solver will be changed ') # 0.22 @pytest.mark.filterwarnings('ignore:Default multi_class will be') # 0.22 def test_plot_partial_dependence_error(data, params, err_msg): + import matplotlib.pyplot as plt # noqa X, y = data estimator = LinearRegression().fit(X, y) with pytest.raises(ValueError, match=err_msg): plot_partial_dependence(estimator, X, **params) - close_figure() + plt.close() @if_matplotlib @@ -559,4 +559,4 @@ def test_plot_partial_dependence_fig(): assert plt.gcf() is fig - close_figure() + plt.close() diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 33b0da90fa9db..1662294189690 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -970,21 +970,3 @@ def check_docstring_parameters(func, doc=None, ignore=None, class_name=None): if n1 != n2: incorrect += [func_name + ' ' + n1 + ' != ' + n2] return incorrect - - -def close_figure(fig=None): - """Close a matplotlibt figure. - - Parameters - ---------- - fig : int or str or Figure, optional (default=None) - The figure, figure number or figure name to close. If ``None``, all - current figures are closed. - """ - from matplotlib.pyplot import get_fignums, close as _close # noqa - - if fig is None: - for fig in get_fignums(): - _close(fig) - else: - _close(fig)