@@ -607,21 +607,23 @@ def check_estimator(estimator=None, generate_only=False, *, legacy: bool = True)
607
607
608
608
name = type (estimator ).__name__
609
609
610
- def checks_generator ():
610
+ def checks_generator (reference_estimator ):
611
611
# we first need to check if the estimator is cloneable for the rest of the tests
612
612
# to run
613
613
yield estimator , partial (check_estimator_cloneable , name )
614
614
for check in _yield_all_checks (estimator , legacy = legacy ):
615
- for check_instance in _yield_instances_for_check (check , estimator ):
616
- maybe_skipped_check = _maybe_skip (check_instance , check )
617
- yield check_instance , partial (maybe_skipped_check , name )
615
+ for check_specific_estimator in _yield_instances_for_check (
616
+ check , reference_estimator
617
+ ):
618
+ maybe_skipped_check = _maybe_skip (check_specific_estimator , check )
619
+ yield check_specific_estimator , partial (maybe_skipped_check , name )
618
620
619
621
if generate_only :
620
- return checks_generator ()
622
+ return checks_generator (estimator )
621
623
622
- for estimator , check in checks_generator ():
624
+ for check_specific_estimator , check in checks_generator (estimator ):
623
625
try :
624
- check (estimator )
626
+ check (check_specific_estimator )
625
627
except SkipTest as exception :
626
628
# SkipTest is thrown when pandas can't be imported, or by checks
627
629
# that are in the xfail_checks tag
0 commit comments