1
1
import re
2
2
from inspect import signature
3
+ import pkgutil
4
+ import inspect
5
+ import importlib
3
6
from typing import Optional
4
7
5
8
import pytest
6
9
from sklearn .utils import all_estimators
10
+ import sklearn
7
11
8
12
# numpydoc is an optional dependency: importorskip returns the imported
# module, or skips collection of this test module when the import fails.
numpydoc_validation = pytest .importorskip ("numpydoc.validate" )
9
13
30
34
"TransformedTargetRegressor" ,
31
35
]
32
36
37
# Fully qualified names ("<defining module>.<function name>") of functions
# whose docstrings are not yet expected to pass numpydoc validation.
# Stored as a set so the per-test membership check is a constant-time lookup.
FUNCTION_DOCSTRING_IGNORE_LIST = {
    "sklearn._config.config_context",
    "sklearn._config.get_config",
    "sklearn.base.clone",
    "sklearn.cluster._affinity_propagation.affinity_propagation",
    "sklearn.cluster._agglomerative.linkage_tree",
    "sklearn.cluster._kmeans.k_means",
    "sklearn.cluster._kmeans.kmeans_plusplus",
    "sklearn.cluster._mean_shift.estimate_bandwidth",
    "sklearn.cluster._mean_shift.get_bin_seeds",
    "sklearn.cluster._mean_shift.mean_shift",
    "sklearn.cluster._optics.cluster_optics_dbscan",
    "sklearn.cluster._optics.cluster_optics_xi",
    "sklearn.cluster._optics.compute_optics_graph",
    "sklearn.cluster._spectral.spectral_clustering",
    "sklearn.compose._column_transformer.make_column_transformer",
    "sklearn.covariance._empirical_covariance.empirical_covariance",
    "sklearn.covariance._empirical_covariance.log_likelihood",
    "sklearn.covariance._graph_lasso.graphical_lasso",
    "sklearn.covariance._robust_covariance.fast_mcd",
    "sklearn.covariance._shrunk_covariance.ledoit_wolf",
    "sklearn.covariance._shrunk_covariance.ledoit_wolf_shrinkage",
    "sklearn.covariance._shrunk_covariance.shrunk_covariance",
    "sklearn.datasets._base.get_data_home",
    "sklearn.datasets._base.load_boston",
    "sklearn.datasets._base.load_breast_cancer",
    "sklearn.datasets._base.load_diabetes",
    "sklearn.datasets._base.load_digits",
    "sklearn.datasets._base.load_files",
    "sklearn.datasets._base.load_iris",
    "sklearn.datasets._base.load_linnerud",
    "sklearn.datasets._base.load_sample_image",
    "sklearn.datasets._base.load_wine",
    "sklearn.datasets._california_housing.fetch_california_housing",
    "sklearn.datasets._covtype.fetch_covtype",
    "sklearn.datasets._kddcup99.fetch_kddcup99",
    "sklearn.datasets._lfw.fetch_lfw_pairs",
    "sklearn.datasets._lfw.fetch_lfw_people",
    "sklearn.datasets._olivetti_faces.fetch_olivetti_faces",
    "sklearn.datasets._openml.fetch_openml",
    "sklearn.datasets._rcv1.fetch_rcv1",
    "sklearn.datasets._samples_generator.make_biclusters",
    "sklearn.datasets._samples_generator.make_blobs",
    "sklearn.datasets._samples_generator.make_checkerboard",
    "sklearn.datasets._samples_generator.make_classification",
    "sklearn.datasets._samples_generator.make_gaussian_quantiles",
    "sklearn.datasets._samples_generator.make_hastie_10_2",
    "sklearn.datasets._samples_generator.make_multilabel_classification",
    "sklearn.datasets._samples_generator.make_regression",
    "sklearn.datasets._samples_generator.make_sparse_coded_signal",
    "sklearn.datasets._samples_generator.make_sparse_spd_matrix",
    "sklearn.datasets._samples_generator.make_spd_matrix",
    "sklearn.datasets._species_distributions.fetch_species_distributions",
    "sklearn.datasets._svmlight_format_io.dump_svmlight_file",
    "sklearn.datasets._svmlight_format_io.load_svmlight_file",
    "sklearn.datasets._svmlight_format_io.load_svmlight_files",
    "sklearn.datasets._twenty_newsgroups.fetch_20newsgroups",
    "sklearn.decomposition._dict_learning.dict_learning",
    "sklearn.decomposition._dict_learning.dict_learning_online",
    "sklearn.decomposition._dict_learning.sparse_encode",
    "sklearn.decomposition._fastica.fastica",
    "sklearn.decomposition._nmf.non_negative_factorization",
    "sklearn.externals._packaging.version.parse",
    "sklearn.feature_extraction.image.extract_patches_2d",
    "sklearn.feature_extraction.image.grid_to_graph",
    "sklearn.feature_extraction.image.img_to_graph",
    "sklearn.feature_extraction.text.strip_accents_ascii",
    "sklearn.feature_extraction.text.strip_accents_unicode",
    "sklearn.feature_extraction.text.strip_tags",
    "sklearn.feature_selection._univariate_selection.chi2",
    "sklearn.feature_selection._univariate_selection.f_oneway",
    "sklearn.feature_selection._univariate_selection.r_regression",
    "sklearn.inspection._partial_dependence.partial_dependence",
    "sklearn.inspection._plot.partial_dependence.plot_partial_dependence",
    "sklearn.isotonic.isotonic_regression",
    "sklearn.linear_model._least_angle.lars_path",
    "sklearn.linear_model._least_angle.lars_path_gram",
    "sklearn.linear_model._omp.orthogonal_mp",
    "sklearn.linear_model._omp.orthogonal_mp_gram",
    "sklearn.linear_model._ridge.ridge_regression",
    "sklearn.manifold._locally_linear.locally_linear_embedding",
    "sklearn.manifold._t_sne.trustworthiness",
    "sklearn.metrics._classification.accuracy_score",
    "sklearn.metrics._classification.balanced_accuracy_score",
    "sklearn.metrics._classification.brier_score_loss",
    "sklearn.metrics._classification.classification_report",
    "sklearn.metrics._classification.cohen_kappa_score",
    "sklearn.metrics._classification.confusion_matrix",
    "sklearn.metrics._classification.f1_score",
    "sklearn.metrics._classification.fbeta_score",
    "sklearn.metrics._classification.hamming_loss",
    "sklearn.metrics._classification.hinge_loss",
    "sklearn.metrics._classification.jaccard_score",
    "sklearn.metrics._classification.log_loss",
    "sklearn.metrics._classification.precision_recall_fscore_support",
    "sklearn.metrics._classification.precision_score",
    "sklearn.metrics._classification.recall_score",
    "sklearn.metrics._classification.zero_one_loss",
    "sklearn.metrics._plot.confusion_matrix.plot_confusion_matrix",
    "sklearn.metrics._plot.det_curve.plot_det_curve",
    "sklearn.metrics._plot.precision_recall_curve.plot_precision_recall_curve",
    "sklearn.metrics._plot.roc_curve.plot_roc_curve",
    "sklearn.metrics._ranking.auc",
    "sklearn.metrics._ranking.average_precision_score",
    "sklearn.metrics._ranking.coverage_error",
    "sklearn.metrics._ranking.dcg_score",
    "sklearn.metrics._ranking.label_ranking_average_precision_score",
    "sklearn.metrics._ranking.label_ranking_loss",
    "sklearn.metrics._ranking.ndcg_score",
    "sklearn.metrics._ranking.precision_recall_curve",
    "sklearn.metrics._ranking.roc_auc_score",
    "sklearn.metrics._ranking.roc_curve",
    "sklearn.metrics._ranking.top_k_accuracy_score",
    "sklearn.metrics._regression.max_error",
    "sklearn.metrics._regression.mean_absolute_error",
    "sklearn.metrics._regression.mean_pinball_loss",
    "sklearn.metrics._scorer.make_scorer",
    "sklearn.metrics.cluster._bicluster.consensus_score",
    "sklearn.metrics.cluster._supervised.adjusted_mutual_info_score",
    "sklearn.metrics.cluster._supervised.adjusted_rand_score",
    "sklearn.metrics.cluster._supervised.completeness_score",
    "sklearn.metrics.cluster._supervised.entropy",
    "sklearn.metrics.cluster._supervised.fowlkes_mallows_score",
    "sklearn.metrics.cluster._supervised.homogeneity_completeness_v_measure",
    "sklearn.metrics.cluster._supervised.homogeneity_score",
    "sklearn.metrics.cluster._supervised.mutual_info_score",
    "sklearn.metrics.cluster._supervised.normalized_mutual_info_score",
    "sklearn.metrics.cluster._supervised.pair_confusion_matrix",
    "sklearn.metrics.cluster._supervised.rand_score",
    "sklearn.metrics.cluster._supervised.v_measure_score",
    "sklearn.metrics.cluster._unsupervised.davies_bouldin_score",
    "sklearn.metrics.cluster._unsupervised.silhouette_samples",
    "sklearn.metrics.cluster._unsupervised.silhouette_score",
    "sklearn.metrics.pairwise.additive_chi2_kernel",
    "sklearn.metrics.pairwise.check_paired_arrays",
    "sklearn.metrics.pairwise.check_pairwise_arrays",
    "sklearn.metrics.pairwise.chi2_kernel",
    "sklearn.metrics.pairwise.cosine_distances",
    "sklearn.metrics.pairwise.cosine_similarity",
    "sklearn.metrics.pairwise.distance_metrics",
    "sklearn.metrics.pairwise.euclidean_distances",
    "sklearn.metrics.pairwise.haversine_distances",
    "sklearn.metrics.pairwise.kernel_metrics",
    "sklearn.metrics.pairwise.laplacian_kernel",
    "sklearn.metrics.pairwise.linear_kernel",
    "sklearn.metrics.pairwise.manhattan_distances",
    "sklearn.metrics.pairwise.nan_euclidean_distances",
    "sklearn.metrics.pairwise.paired_cosine_distances",
    "sklearn.metrics.pairwise.paired_distances",
    "sklearn.metrics.pairwise.paired_euclidean_distances",
    "sklearn.metrics.pairwise.paired_manhattan_distances",
    "sklearn.metrics.pairwise.pairwise_distances_argmin",
    "sklearn.metrics.pairwise.pairwise_distances_argmin_min",
    "sklearn.metrics.pairwise.pairwise_distances_chunked",
    "sklearn.metrics.pairwise.pairwise_kernels",
    "sklearn.metrics.pairwise.polynomial_kernel",
    "sklearn.metrics.pairwise.rbf_kernel",
    "sklearn.metrics.pairwise.sigmoid_kernel",
    "sklearn.model_selection._split.check_cv",
    "sklearn.model_selection._split.train_test_split",
    "sklearn.model_selection._validation.cross_val_predict",
    "sklearn.model_selection._validation.cross_val_score",
    "sklearn.model_selection._validation.cross_validate",
    "sklearn.model_selection._validation.learning_curve",
    "sklearn.model_selection._validation.permutation_test_score",
    "sklearn.model_selection._validation.validation_curve",
    "sklearn.neighbors._graph.kneighbors_graph",
    "sklearn.neighbors._graph.radius_neighbors_graph",
    "sklearn.pipeline.make_union",
    "sklearn.preprocessing._data.binarize",
    "sklearn.preprocessing._data.maxabs_scale",
    "sklearn.preprocessing._data.normalize",
    "sklearn.preprocessing._data.power_transform",
    "sklearn.preprocessing._data.quantile_transform",
    "sklearn.preprocessing._data.robust_scale",
    "sklearn.preprocessing._data.scale",
    "sklearn.preprocessing._label.label_binarize",
    "sklearn.random_projection.johnson_lindenstrauss_min_dim",
    "sklearn.svm._bounds.l1_min_c",
    "sklearn.tree._export.plot_tree",
    "sklearn.utils.axis0_safe_slice",
    "sklearn.utils.check_pandas_support",
    "sklearn.utils.extmath.cartesian",
    "sklearn.utils.extmath.density",
    "sklearn.utils.extmath.fast_logdet",
    "sklearn.utils.extmath.randomized_range_finder",
    "sklearn.utils.extmath.randomized_svd",
    "sklearn.utils.extmath.safe_sparse_dot",
    "sklearn.utils.extmath.squared_norm",
    "sklearn.utils.extmath.stable_cumsum",
    "sklearn.utils.extmath.svd_flip",
    "sklearn.utils.extmath.weighted_mode",
    "sklearn.utils.fixes.delayed",
    "sklearn.utils.fixes.linspace",
    "sklearn.utils.gen_batches",
    "sklearn.utils.gen_even_slices",
    "sklearn.utils.get_chunk_n_rows",
    "sklearn.utils.graph.graph_shortest_path",
    "sklearn.utils.graph.single_source_shortest_path_length",
    "sklearn.utils.is_scalar_nan",
    "sklearn.utils.metaestimators.available_if",
    "sklearn.utils.metaestimators.if_delegate_has_method",
    "sklearn.utils.multiclass.check_classification_targets",
    "sklearn.utils.multiclass.class_distribution",
    "sklearn.utils.multiclass.type_of_target",
    "sklearn.utils.multiclass.unique_labels",
    "sklearn.utils.resample",
    "sklearn.utils.safe_mask",
    "sklearn.utils.safe_sqr",
    "sklearn.utils.shuffle",
    "sklearn.utils.sparsefuncs.count_nonzero",
    "sklearn.utils.sparsefuncs.csc_median_axis_0",
    "sklearn.utils.sparsefuncs.incr_mean_variance_axis",
    "sklearn.utils.sparsefuncs.inplace_swap_column",
    "sklearn.utils.sparsefuncs.inplace_swap_row",
    "sklearn.utils.sparsefuncs.inplace_swap_row_csc",
    "sklearn.utils.sparsefuncs.inplace_swap_row_csr",
    "sklearn.utils.sparsefuncs.mean_variance_axis",
    "sklearn.utils.sparsefuncs.min_max_axis",
    "sklearn.utils.tosequence",
    "sklearn.utils.validation.as_float_array",
    "sklearn.utils.validation.assert_all_finite",
    "sklearn.utils.validation.check_is_fitted",
    "sklearn.utils.validation.check_memory",
    "sklearn.utils.validation.check_random_state",
    "sklearn.utils.validation.column_or_1d",
    "sklearn.utils.validation.has_fit_parameter",
    "sklearn.utils.validation.indexable",
}
267
+
33
268
34
269
def get_all_methods ():
35
270
estimators = all_estimators ()
@@ -50,6 +285,51 @@ def get_all_methods():
50
285
yield Estimator , method
51
286
52
287
288
+ def _is_checked_function (item ):
289
+ if not inspect .isfunction (item ):
290
+ return False
291
+
292
+ if item .__name__ .startswith ("_" ):
293
+ return False
294
+
295
+ mod = item .__module__
296
+ if not mod .startswith ("sklearn." ) or mod .endswith ("estimator_checks" ):
297
+ return False
298
+
299
+ return True
300
+
301
+
302
def get_all_functions_names():
    """Get all public functions defined in the sklearn module.

    Walks the submodules of `sklearn`, imports each one that is not ignored,
    and collects the members accepted by `_is_checked_function`.

    A module is ignored when any dotted component of its name is one of
    `modules_to_ignore`, or when its name contains "._" (a component that
    starts with an underscore).

    Returns
    -------
    list of str
        Sorted, de-duplicated names of the form
        "<func.__module__>.<func.__name__>".
    """
    modules_to_ignore = {
        "tests",
        "externals",
        "setup",
        "conftest",
        "experimental",
        "estimator_checks",
    }

    all_functions_names = set()
    for _, module_name, _ in pkgutil.walk_packages(
        path=sklearn.__path__, prefix="sklearn."
    ):
        module_parts = module_name.split(".")
        if (
            any(part in modules_to_ignore for part in module_parts)
            or "._" in module_name
        ):
            continue

        module = importlib.import_module(module_name)
        for _, func in inspect.getmembers(module, _is_checked_function):
            # The name is built from the function's own `__module__`, not
            # from the module being walked, so a function reachable from
            # several modules is recorded once, under where it is defined.
            all_functions_names.add(f"{func.__module__}.{func.__name__}")

    return sorted(all_functions_names)
331
+
332
+
53
333
def filter_errors (errors , method , Estimator = None ):
54
334
"""
55
335
Ignore some errors based on the method type.
@@ -144,6 +424,24 @@ def repr_errors(res, estimator=None, method: Optional[str] = None) -> str:
144
424
return msg
145
425
146
426
427
@pytest.mark.parametrize("function_name", get_all_functions_names())
def test_function_docstring(function_name, request):
    """Check function docstrings using numpydoc.

    Parameters
    ----------
    function_name : str
        Dotted import path of the function to validate, as produced by
        `get_all_functions_names`.
    request : pytest fixture
        Used to attach an xfail marker to the running test.

    Raises
    ------
    ValueError
        If numpydoc reports errors that `filter_errors` does not discard.
    """
    if function_name in FUNCTION_DOCSTRING_IGNORE_LIST:
        # Functions on the ignore list are marked as expected failures
        # rather than being dropped from the parametrization.
        request.applymarker(
            pytest.mark.xfail(run=False, reason="TODO pass numpydoc validation")
        )

    res = numpydoc_validation.validate(function_name)

    res["errors"] = list(filter_errors(res["errors"], method="function"))

    if res["errors"]:
        msg = repr_errors(res, method=f"Tested function: {function_name}")

        raise ValueError(msg)
443
+
444
+
147
445
@pytest .mark .parametrize ("Estimator, method" , get_all_methods ())
148
446
def test_docstring (Estimator , method , request ):
149
447
base_import_path = Estimator .__module__
0 commit comments