1
1
import re
2
2
from inspect import signature
3
+ import pkgutil
4
+ import inspect
5
+ import importlib
3
6
from typing import Optional
4
7
5
8
import pytest
6
9
from sklearn .utils import all_estimators
10
+ import sklearn
7
11
8
12
# numpydoc is an optional dependency: skip this whole test module when
# ``numpydoc.validate`` cannot be imported, otherwise keep a handle on it.
numpydoc_validation = pytest .importorskip ("numpydoc.validate" )
9
13
14
18
"StackingRegressor" ,
15
19
]
16
20
21
# Fully qualified names of the functions that ``test_function_docstring``
# marks as expected failures ("TODO pass numpydoc validation"). Stored as a
# set so the per-test membership check is a hash lookup.
FUNCTION_DOCSTRING_IGNORE_LIST = {
    "sklearn._config.config_context",
    "sklearn._config.get_config",
    "sklearn.base.clone",
    "sklearn.cluster._affinity_propagation.affinity_propagation",
    "sklearn.cluster._agglomerative.linkage_tree",
    "sklearn.cluster._kmeans.k_means",
    "sklearn.cluster._kmeans.kmeans_plusplus",
    "sklearn.cluster._mean_shift.estimate_bandwidth",
    "sklearn.cluster._mean_shift.get_bin_seeds",
    "sklearn.cluster._mean_shift.mean_shift",
    "sklearn.cluster._optics.cluster_optics_dbscan",
    "sklearn.cluster._optics.cluster_optics_xi",
    "sklearn.cluster._optics.compute_optics_graph",
    "sklearn.cluster._spectral.spectral_clustering",
    "sklearn.compose._column_transformer.make_column_transformer",
    "sklearn.covariance._empirical_covariance.empirical_covariance",
    "sklearn.covariance._empirical_covariance.log_likelihood",
    "sklearn.covariance._graph_lasso.graphical_lasso",
    "sklearn.covariance._robust_covariance.fast_mcd",
    "sklearn.covariance._shrunk_covariance.ledoit_wolf",
    "sklearn.covariance._shrunk_covariance.ledoit_wolf_shrinkage",
    "sklearn.covariance._shrunk_covariance.shrunk_covariance",
    "sklearn.datasets._base.get_data_home",
    "sklearn.datasets._base.load_boston",
    "sklearn.datasets._base.load_breast_cancer",
    "sklearn.datasets._base.load_diabetes",
    "sklearn.datasets._base.load_digits",
    "sklearn.datasets._base.load_files",
    "sklearn.datasets._base.load_iris",
    "sklearn.datasets._base.load_linnerud",
    "sklearn.datasets._base.load_sample_image",
    "sklearn.datasets._base.load_wine",
    "sklearn.datasets._california_housing.fetch_california_housing",
    "sklearn.datasets._covtype.fetch_covtype",
    "sklearn.datasets._kddcup99.fetch_kddcup99",
    "sklearn.datasets._lfw.fetch_lfw_pairs",
    "sklearn.datasets._lfw.fetch_lfw_people",
    "sklearn.datasets._olivetti_faces.fetch_olivetti_faces",
    "sklearn.datasets._openml.fetch_openml",
    "sklearn.datasets._rcv1.fetch_rcv1",
    "sklearn.datasets._samples_generator.make_biclusters",
    "sklearn.datasets._samples_generator.make_blobs",
    "sklearn.datasets._samples_generator.make_checkerboard",
    "sklearn.datasets._samples_generator.make_classification",
    "sklearn.datasets._samples_generator.make_gaussian_quantiles",
    "sklearn.datasets._samples_generator.make_hastie_10_2",
    "sklearn.datasets._samples_generator.make_multilabel_classification",
    "sklearn.datasets._samples_generator.make_regression",
    "sklearn.datasets._samples_generator.make_sparse_coded_signal",
    "sklearn.datasets._samples_generator.make_sparse_spd_matrix",
    "sklearn.datasets._samples_generator.make_spd_matrix",
    "sklearn.datasets._species_distributions.fetch_species_distributions",
    "sklearn.datasets._svmlight_format_io.dump_svmlight_file",
    "sklearn.datasets._svmlight_format_io.load_svmlight_file",
    "sklearn.datasets._svmlight_format_io.load_svmlight_files",
    "sklearn.datasets._twenty_newsgroups.fetch_20newsgroups",
    "sklearn.decomposition._dict_learning.dict_learning",
    "sklearn.decomposition._dict_learning.dict_learning_online",
    "sklearn.decomposition._dict_learning.sparse_encode",
    "sklearn.decomposition._fastica.fastica",
    "sklearn.decomposition._nmf.non_negative_factorization",
    "sklearn.externals._packaging.version.parse",
    "sklearn.feature_extraction.image.extract_patches_2d",
    "sklearn.feature_extraction.image.grid_to_graph",
    "sklearn.feature_extraction.image.img_to_graph",
    "sklearn.feature_extraction.text.strip_accents_ascii",
    "sklearn.feature_extraction.text.strip_accents_unicode",
    "sklearn.feature_extraction.text.strip_tags",
    "sklearn.feature_selection._univariate_selection.chi2",
    "sklearn.feature_selection._univariate_selection.f_oneway",
    "sklearn.feature_selection._univariate_selection.r_regression",
    "sklearn.inspection._partial_dependence.partial_dependence",
    "sklearn.inspection._plot.partial_dependence.plot_partial_dependence",
    "sklearn.isotonic.isotonic_regression",
    "sklearn.linear_model._least_angle.lars_path",
    "sklearn.linear_model._least_angle.lars_path_gram",
    "sklearn.linear_model._omp.orthogonal_mp",
    "sklearn.linear_model._omp.orthogonal_mp_gram",
    "sklearn.linear_model._ridge.ridge_regression",
    "sklearn.manifold._locally_linear.locally_linear_embedding",
    "sklearn.manifold._t_sne.trustworthiness",
    "sklearn.metrics._classification.accuracy_score",
    "sklearn.metrics._classification.balanced_accuracy_score",
    "sklearn.metrics._classification.brier_score_loss",
    "sklearn.metrics._classification.classification_report",
    "sklearn.metrics._classification.cohen_kappa_score",
    "sklearn.metrics._classification.confusion_matrix",
    "sklearn.metrics._classification.f1_score",
    "sklearn.metrics._classification.fbeta_score",
    "sklearn.metrics._classification.hamming_loss",
    "sklearn.metrics._classification.hinge_loss",
    "sklearn.metrics._classification.jaccard_score",
    "sklearn.metrics._classification.log_loss",
    "sklearn.metrics._classification.precision_recall_fscore_support",
    "sklearn.metrics._classification.precision_score",
    "sklearn.metrics._classification.recall_score",
    "sklearn.metrics._classification.zero_one_loss",
    "sklearn.metrics._plot.confusion_matrix.plot_confusion_matrix",
    "sklearn.metrics._plot.det_curve.plot_det_curve",
    "sklearn.metrics._plot.precision_recall_curve.plot_precision_recall_curve",
    "sklearn.metrics._plot.roc_curve.plot_roc_curve",
    "sklearn.metrics._ranking.auc",
    "sklearn.metrics._ranking.average_precision_score",
    "sklearn.metrics._ranking.coverage_error",
    "sklearn.metrics._ranking.dcg_score",
    "sklearn.metrics._ranking.label_ranking_average_precision_score",
    "sklearn.metrics._ranking.label_ranking_loss",
    "sklearn.metrics._ranking.ndcg_score",
    "sklearn.metrics._ranking.precision_recall_curve",
    "sklearn.metrics._ranking.roc_auc_score",
    "sklearn.metrics._ranking.roc_curve",
    "sklearn.metrics._ranking.top_k_accuracy_score",
    "sklearn.metrics._regression.max_error",
    "sklearn.metrics._regression.mean_absolute_error",
    "sklearn.metrics._regression.mean_pinball_loss",
    "sklearn.metrics._scorer.make_scorer",
    "sklearn.metrics.cluster._bicluster.consensus_score",
    "sklearn.metrics.cluster._supervised.adjusted_mutual_info_score",
    "sklearn.metrics.cluster._supervised.adjusted_rand_score",
    "sklearn.metrics.cluster._supervised.completeness_score",
    "sklearn.metrics.cluster._supervised.entropy",
    "sklearn.metrics.cluster._supervised.fowlkes_mallows_score",
    "sklearn.metrics.cluster._supervised.homogeneity_completeness_v_measure",
    "sklearn.metrics.cluster._supervised.homogeneity_score",
    "sklearn.metrics.cluster._supervised.mutual_info_score",
    "sklearn.metrics.cluster._supervised.normalized_mutual_info_score",
    "sklearn.metrics.cluster._supervised.pair_confusion_matrix",
    "sklearn.metrics.cluster._supervised.rand_score",
    "sklearn.metrics.cluster._supervised.v_measure_score",
    "sklearn.metrics.cluster._unsupervised.davies_bouldin_score",
    "sklearn.metrics.cluster._unsupervised.silhouette_samples",
    "sklearn.metrics.cluster._unsupervised.silhouette_score",
    "sklearn.metrics.pairwise.additive_chi2_kernel",
    "sklearn.metrics.pairwise.check_paired_arrays",
    "sklearn.metrics.pairwise.check_pairwise_arrays",
    "sklearn.metrics.pairwise.chi2_kernel",
    "sklearn.metrics.pairwise.cosine_distances",
    "sklearn.metrics.pairwise.cosine_similarity",
    "sklearn.metrics.pairwise.distance_metrics",
    "sklearn.metrics.pairwise.euclidean_distances",
    "sklearn.metrics.pairwise.haversine_distances",
    "sklearn.metrics.pairwise.kernel_metrics",
    "sklearn.metrics.pairwise.laplacian_kernel",
    "sklearn.metrics.pairwise.linear_kernel",
    "sklearn.metrics.pairwise.manhattan_distances",
    "sklearn.metrics.pairwise.nan_euclidean_distances",
    "sklearn.metrics.pairwise.paired_cosine_distances",
    "sklearn.metrics.pairwise.paired_distances",
    "sklearn.metrics.pairwise.paired_euclidean_distances",
    "sklearn.metrics.pairwise.paired_manhattan_distances",
    "sklearn.metrics.pairwise.pairwise_distances_argmin",
    "sklearn.metrics.pairwise.pairwise_distances_argmin_min",
    "sklearn.metrics.pairwise.pairwise_distances_chunked",
    "sklearn.metrics.pairwise.pairwise_kernels",
    "sklearn.metrics.pairwise.polynomial_kernel",
    "sklearn.metrics.pairwise.rbf_kernel",
    "sklearn.metrics.pairwise.sigmoid_kernel",
    "sklearn.model_selection._split.check_cv",
    "sklearn.model_selection._split.train_test_split",
    "sklearn.model_selection._validation.cross_val_predict",
    "sklearn.model_selection._validation.cross_val_score",
    "sklearn.model_selection._validation.cross_validate",
    "sklearn.model_selection._validation.learning_curve",
    "sklearn.model_selection._validation.permutation_test_score",
    "sklearn.model_selection._validation.validation_curve",
    "sklearn.neighbors._graph.kneighbors_graph",
    "sklearn.neighbors._graph.radius_neighbors_graph",
    "sklearn.pipeline.make_union",
    "sklearn.preprocessing._data.binarize",
    "sklearn.preprocessing._data.maxabs_scale",
    "sklearn.preprocessing._data.normalize",
    "sklearn.preprocessing._data.power_transform",
    "sklearn.preprocessing._data.quantile_transform",
    "sklearn.preprocessing._data.robust_scale",
    "sklearn.preprocessing._data.scale",
    "sklearn.preprocessing._label.label_binarize",
    "sklearn.random_projection.johnson_lindenstrauss_min_dim",
    "sklearn.svm._bounds.l1_min_c",
    "sklearn.tree._export.plot_tree",
    "sklearn.utils.axis0_safe_slice",
    "sklearn.utils.check_pandas_support",
    "sklearn.utils.extmath.cartesian",
    "sklearn.utils.extmath.density",
    "sklearn.utils.extmath.fast_logdet",
    "sklearn.utils.extmath.randomized_range_finder",
    "sklearn.utils.extmath.randomized_svd",
    "sklearn.utils.extmath.safe_sparse_dot",
    "sklearn.utils.extmath.squared_norm",
    "sklearn.utils.extmath.stable_cumsum",
    "sklearn.utils.extmath.svd_flip",
    "sklearn.utils.extmath.weighted_mode",
    "sklearn.utils.fixes.delayed",
    "sklearn.utils.fixes.linspace",
    "sklearn.utils.gen_batches",
    "sklearn.utils.gen_even_slices",
    "sklearn.utils.get_chunk_n_rows",
    "sklearn.utils.graph.graph_shortest_path",
    "sklearn.utils.graph.single_source_shortest_path_length",
    "sklearn.utils.is_scalar_nan",
    "sklearn.utils.metaestimators.available_if",
    "sklearn.utils.metaestimators.if_delegate_has_method",
    "sklearn.utils.multiclass.check_classification_targets",
    "sklearn.utils.multiclass.class_distribution",
    "sklearn.utils.multiclass.type_of_target",
    "sklearn.utils.multiclass.unique_labels",
    "sklearn.utils.resample",
    "sklearn.utils.safe_mask",
    "sklearn.utils.safe_sqr",
    "sklearn.utils.shuffle",
    "sklearn.utils.sparsefuncs.count_nonzero",
    "sklearn.utils.sparsefuncs.csc_median_axis_0",
    "sklearn.utils.sparsefuncs.incr_mean_variance_axis",
    "sklearn.utils.sparsefuncs.inplace_swap_column",
    "sklearn.utils.sparsefuncs.inplace_swap_row",
    "sklearn.utils.sparsefuncs.inplace_swap_row_csc",
    "sklearn.utils.sparsefuncs.inplace_swap_row_csr",
    "sklearn.utils.sparsefuncs.mean_variance_axis",
    "sklearn.utils.sparsefuncs.min_max_axis",
    "sklearn.utils.tosequence",
    "sklearn.utils.validation.as_float_array",
    "sklearn.utils.validation.assert_all_finite",
    "sklearn.utils.validation.check_is_fitted",
    "sklearn.utils.validation.check_memory",
    "sklearn.utils.validation.check_random_state",
    "sklearn.utils.validation.column_or_1d",
    "sklearn.utils.validation.has_fit_parameter",
    "sklearn.utils.validation.indexable",
}

251
+
17
252
18
253
def get_all_methods ():
19
254
estimators = all_estimators ()
@@ -34,6 +269,51 @@ def get_all_methods():
34
269
yield Estimator , method
35
270
36
271
272
def _is_checked_function(item):
    """Tell whether ``item`` is a function whose docstring gets validated.

    Only plain Python functions qualify, and only when their name has no
    leading underscore and their ``__module__`` starts with ``"sklearn."``
    without ending in ``"estimator_checks"``.
    """
    # The function check must come first: other objects may lack ``__name__``.
    if not inspect.isfunction(item) or item.__name__.startswith("_"):
        return False

    module_path = item.__module__
    return module_path.startswith("sklearn.") and not module_path.endswith(
        "estimator_checks"
    )

284
+
285
+
286
def get_all_functions_names():
    """Get the names of all public functions defined in the sklearn package.

    Every submodule found under ``sklearn.__path__`` is visited, except
    modules with a dotted component listed in ``modules_to_ignore`` and
    modules with an underscore-prefixed component below the top-level
    package. Each remaining module is imported and the members accepted by
    ``_is_checked_function`` are recorded as
    ``"<func.__module__>.<func.__name__>"``.

    Returns
    -------
    list of str
        Sorted, de-duplicated fully qualified function names.
    """
    modules_to_ignore = {
        "tests",
        "externals",
        "setup",
        "conftest",
        "experimental",
        "estimator_checks",
    }

    all_functions_names = set()
    for _, module_name, _ in pkgutil.walk_packages(
        path=sklearn.__path__, prefix="sklearn."
    ):
        # Skip private modules and any module living under an ignored part.
        if "._" in module_name or not modules_to_ignore.isdisjoint(
            module_name.split(".")
        ):
            continue

        module = importlib.import_module(module_name)
        # The name is built without any padding inside the literal: it is
        # passed to ``numpydoc_validation.validate`` and compared against
        # FUNCTION_DOCSTRING_IGNORE_LIST, so it must be an exact dotted path.
        all_functions_names.update(
            f"{func.__module__}.{func.__name__}"
            for _, func in inspect.getmembers(module, _is_checked_function)
        )

    return sorted(all_functions_names)

315
+
316
+
37
317
def filter_errors (errors , method , Estimator = None ):
38
318
"""
39
319
Ignore some errors based on the method type.
@@ -128,6 +408,24 @@ def repr_errors(res, estimator=None, method: Optional[str] = None) -> str:
128
408
return msg
129
409
130
410
411
@pytest.mark.parametrize("function_name", get_all_functions_names())
def test_function_docstring(function_name, request):
    """Check function docstrings using numpydoc.

    Functions listed in FUNCTION_DOCSTRING_IGNORE_LIST are marked as expected
    failures. For the others, numpydoc validation errors that survive
    ``filter_errors`` make the test fail with a formatted report.

    Parameters
    ----------
    function_name : str
        Fully qualified name of the function to validate.
    request : pytest fixture
        Used to attach the ``xfail`` marker to the current test.
    """
    if function_name in FUNCTION_DOCSTRING_IGNORE_LIST:
        request.applymarker(
            pytest.mark.xfail(run=False, reason="TODO pass numpydoc validation")
        )

    res = numpydoc_validation.validate(function_name)

    res["errors"] = list(filter_errors(res["errors"], method="function"))

    if res["errors"]:
        # No padding after the name, so the report label is the exact path.
        msg = repr_errors(res, method=f"Tested function: {function_name}")

        raise ValueError(msg)

427
+
428
+
131
429
@pytest .mark .parametrize ("Estimator, method" , get_all_methods ())
132
430
def test_docstring (Estimator , method , request ):
133
431
base_import_path = Estimator .__module__
0 commit comments