@@ -126,8 +126,7 @@ def test_normalized_output(metric_name):
 # 0.22 AMI and NMI changes
 @pytest.mark.filterwarnings('ignore::FutureWarning')
 @pytest.mark.parametrize(
-    "metric_name",
-    dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
+    "metric_name", dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
 )
 def test_permute_labels(metric_name):
     # All clustering metrics do not change score due to permutations of labels
@@ -150,11 +149,10 @@ def test_permute_labels(metric_name):
 # 0.22 AMI and NMI changes
 @pytest.mark.filterwarnings('ignore::FutureWarning')
 @pytest.mark.parametrize(
-    "metric_name",
-    dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
+    "metric_name", dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
 )
 # For all clustering metrics Input parameters can be both
-# in the form of arrays lists, positive, negetive or string
+# in the form of arrays lists, positive, negative or string
 def test_format_invariance(metric_name):
     y_true = [0, 0, 0, 0, 1, 1, 1, 1]
     y_pred = [0, 1, 2, 3, 4, 5, 6, 7]
@@ -183,3 +181,29 @@ def generate_formats(y):
         y_true_gen = generate_formats(y_true)
         for (y_true_fmt, fmt_name) in y_true_gen:
             assert score_1 == metric(X, y_true_fmt)
+
+
+@pytest.mark.parametrize("metric", SUPERVISED_METRICS.values())
+def test_single_sample(metric):
+    # only the supervised metrics support single sample
+    for i, j in [(0, 0), (0, 1), (1, 0), (1, 1)]:
+        metric([i], [j])
+
+
+@pytest.mark.parametrize(
+    "metric_name, metric_func",
+    dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS).items()
+)
+def test_inf_nan_input(metric_name, metric_func):
+    if metric_name in SUPERVISED_METRICS:
+        invalids = [([0, 1], [np.inf, np.inf]),
+                    ([0, 1], [np.nan, np.nan]),
+                    ([0, 1], [np.nan, np.inf])]
+    else:
+        X = np.random.randint(10, size=(2, 10))
+        invalids = [(X, [np.inf, np.inf]),
+                    (X, [np.nan, np.nan]),
+                    (X, [np.nan, np.inf])]
+    for args in invalids:
+        with pytest.raises(ValueError, match='contains NaN, infinity'):
+            metric_func(*args)
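
For readers unfamiliar with the fixtures the parametrized tests above rely on, the sketch below shows how the SUPERVISED_METRICS and UNSUPERVISED_METRICS dictionaries might be populated and why the NaN/infinity check in test_inf_nan_input fires. The specific metric entries here are assumptions for illustration; the real dictionaries are defined near the top of the test module and are not part of this hunk.

# Illustrative sketch only: entries below are assumed, not the module's
# actual dictionaries.
import numpy as np
from sklearn import metrics

SUPERVISED_METRICS = {
    "adjusted_rand_score": metrics.adjusted_rand_score,   # assumed entry
    "v_measure_score": metrics.v_measure_score,           # assumed entry
}
UNSUPERVISED_METRICS = {
    "silhouette_score": metrics.silhouette_score,         # assumed entry
}

# Supervised metrics validate both label arrays, so NaN/inf labels raise the
# ValueError that test_inf_nan_input matches on.
try:
    metrics.adjusted_rand_score([0, 1], [np.nan, np.inf])
except ValueError as exc:
    print(exc)  # message mentions NaN / infinity, which pytest.raises matches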