10000 TST set random state in check_classifiers_train · scikit-learn/scikit-learn@ac7c88c · GitHub
[go: up one dir, main page]

Skip to content

Commit ac7c88c

Browse files
committed
TST set random state in check_classifiers_train
1 parent 45982e0 commit ac7c88c

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def check_estimator_sparse_data(name, Estimator):
140140
if hasattr(estimator, 'predict_proba'):
141141
estimator.predict_proba(X)
142142
except TypeError as e:
143-
if not 'sparse' in repr(e):
143+
if 'sparse' not in repr(e):
144144
print("Estimator %s doesn't seem to fail gracefully on "
145145
"sparse data: error message state explicitly that "
146146
"sparse input is not supported if this is not the case."
@@ -280,7 +280,7 @@ def check_estimators_nan_inf(name, Estimator):
280280
else:
281281
estimator.fit(X_train, y)
282282
except ValueError as e:
283-
if not 'inf' in repr(e) and not 'NaN' in repr(e):
283+
if 'inf' not in repr(e) and 'NaN' not in repr(e):
284284
print(error_string_fit, Estimator, e)
285285
traceback.print_exc(file=sys.stdout)
286286
raise e
@@ -303,7 +303,7 @@ def check_estimators_nan_inf(name, Estimator):
303303
try:
304304
estimator.predict(X_train)
305305
except ValueError as e:
306-
if not 'inf' in repr(e) and not 'NaN' in repr(e):
306+
if 'inf' not in repr(e) and 'NaN' not in repr(e):
307307
print(error_string_predict, Estimator, e)
308308
traceback.print_exc(file=sys.stdout)
309309
raise e
@@ -318,7 +318,7 @@ def check_estimators_nan_inf(name, Estimator):
318318
try:
319319
estimator.transform(X_train)
320320
except ValueError as e:
321-
if not 'inf' in repr(e) and not 'NaN' in repr(e):
321+
if 'inf' not in repr(e) and 'NaN' not in repr(e):
322322
print(error_string_transform, Estimator, e)
323323
traceback.print_exc(file=sys.stdout)
324324
raise e
@@ -444,7 +444,7 @@ def check_classifiers_one_label(name, Classifier):
444444
try:
445445
classifier.fit(X_train, y)
446446
except ValueError as e:
447-
if not 'class' in repr(e):
447+
if 'class' not in repr(e):
448448
print(error_string_fit, Classifier, e)
449449
traceback.print_exc(file=sys.stdout)
450450
raise e
@@ -479,6 +479,7 @@ def check_classifiers_train(name, Classifier):
479479
if name in ['BernoulliNB', 'MultinomialNB']:
480480
X -= X.min()
481481
set_fast_parameters(classifier)
482+
set_random_state(classifier)
482483
# raises error on malformed input for fit
483484
assert_raises(ValueError, classifier.fit, X, y[:-1])
484485

@@ -497,7 +498,7 @@ def check_classifiers_train(name, Classifier):
497498
assert_raises(ValueError, classifier.predict, X.T)
498499
if hasattr(classifier, "decision_function"):
499500
try:
500-
# decision_function agrees with predict
501+
# decision_function agrees with predict
501502
decision = classifier.decision_function(X)
502503
if n_classes is 2:
503504
assert_equal(decision.shape, (n_samples,))

0 commit comments

Comments
 (0)
0