@@ -1697,7 +1697,7 @@ def test_validation_curve_cv_splits_consistency():
1697
1697
assert_array_almost_equal (np .array (scores3 ), np .array (scores1 ))
1698
1698
1699
1699
1700
- def test_validation_curve_fit_params ():
1700
+ def test_validation_curve_params ():
1701
1701
X = np
1CF5
.arange (100 ).reshape (10 , 10 )
1702
1702
y = np .array ([0 ] * 5 + [1 ] * 5 )
1703
1703
clf = CheckingClassifier (expected_sample_weight = True )
@@ -1722,7 +1722,7 @@ def test_validation_curve_fit_params():
1722
1722
param_name = "foo_param" ,
1723
1723
param_range = [1 , 2 , 3 ],
1724
1724
error_score = "raise" ,
1725
- fit_params = {"sample_weight" : np .ones (1 )},
1725
+ params = {"sample_weight" : np .ones (1 )},
1726
1726
)
1727
1727
validation_curve (
1728
1728
clf ,
@@ -1731,7 +1731,7 @@ def test_validation_curve_fit_params():
1731
1731
param_name = "foo_param" ,
1732
1732
param_range = [1 , 2 , 3 ],
1733
1733
error_score = "raise" ,
1734
- fit_params = {"sample_weight" : np .ones (10 )},
1734
+ params = {"sample_weight" : np .ones (10 )},
1735
1735
)
1736
1736
1737
1737
@@ -2482,29 +2482,54 @@ def test_cross_validate_return_indices(global_random_seed):
2482
2482
assert_array_equal (test_indices [split_idx ], expected_test_idx )
2483
2483
2484
2484
2485
- # Tests for metadata routing in cross_val* and learning_curve
2486
- # ===========================================================
2485
+ # Tests for metadata routing in cross_val* and in *curve
2486
+ # ======================================================
2487
2487
2488
2488
2489
2489
# TODO(1.6): remove `cross_validate` and `cross_val_predict` from this test in 1.6 and
2490
- # `learning_curve` in 1.8
2491
- @pytest .mark .parametrize ("func" , [cross_validate , cross_val_predict , learning_curve ])
2492
- def test_fit_param_deprecation (func ):
2490
+ # `learning_curve` and `validation_curve` in 1.8
2491
+ @pytest .mark .parametrize (
2492
+ "func, extra_args" ,
2493
+ [
2494
+ (cross_validate , {}),
2495
+ (cross_val_score , {}),
2496
+ (cross_val_predict , {}),
2497
+ (learning_curve , {}),
2498
+ (validation_curve , {"param_name" : "alpha" , "param_range" : np .array ([1 ])}),
2499
+ ],
2500
+ )
2501
+ def test_fit_param_deprecation (func , extra_args ):
2493
2502
"""Check that we warn about deprecating `fit_params`."""
2494
2503
with pytest .warns (FutureWarning , match = "`fit_params` is deprecated" ):
2495
- func (estimator = ConsumingClassifier (), X = X , y = y , cv = 2 , fit_params = {})
2504
+ func (
2505
+ estimator = ConsumingClassifier (), X = X , y = y , cv = 2 , fit_params = {}, ** extra_args
2506
+ )
2496
2507
2497
2508
with pytest .raises (
2498
2509
ValueError , match = "`params` and `fit_params` cannot both be provided"
2499
2510
):
2500
- func (estimator = ConsumingClassifier (), X = X , y = y , fit_params = {}, params = {})
2511
+ func (
2512
+ estimator = ConsumingClassifier (),
2513
+ X = X ,
2514
+ y = y ,
2515
+ fit_params = {},
2516
+ params = {},
2517
+ ** extra_args ,
2518
+ )
2501
2519
2502
2520
2503
2521
@pytest .mark .usefixtures ("enable_slep006" )
2504
2522
@pytest .mark .parametrize (
2505
- "func" , [cross_validate , cross_val_score , cross_val_predict , learning_curve ]
2523
+ "func, extra_args" ,
2524
+ [
2525
+ (cross_validate , {}),
2526
+ (cross_val_score , {}),
2527
+ (cross_val_predict , {}),
2528
+ (learning_curve , {}),
2529
+ (validation_curve , {"param_name" : "alpha" , "param_range" : np .array ([1 ])}),
2530
+ ],
2506
2531
)
2507
- def test_groups_with_routing_validation (func ):
2532
+ def test_groups_with_routing_validation (func , extra_args ):
2508
2533
"""Check that we raise an error if `groups` are passed to the cv method instead
2509
2534
of `params` when metadata routing is enabled.
2510
2535
"""
@@ -2514,14 +2539,22 @@ def test_groups_with_routing_validation(func):
2514
2539
X = X ,
2515
2540
y = y ,
2516
2541
groups = [],
2542
+ ** extra_args ,
2517
2543
)
2518
2544
2519
2545
2520
2546
@pytest .mark .usefixtures ("enable_slep006" )
2521
2547
@pytest .mark .parametrize (
2522
- "func" , [cross_validate , cross_val_score , cross_val_predict , learning_curve ]
2548
+ "func, extra_args" ,
2549
+ [
2550
+ (cross_validate , {}),
2551
+ (cross_val_score , {}),
2552
+ (cross_val_predict , {}),
2553
+ (learning_curve , {}),
2554
+ (validation_curve , {"param_name" : "alpha" , "param_range" : np .array ([1 ])}),
2555
+ ],
2523
2556
)
2524
- def test_passed_unrequested_metadata (func ):
2557
+ def test_passed_unrequested_metadata (func , extra_args ):
2525
2558
"""Check that we raise an error when passing metadata that is not
2526
2559
requested."""
2527
2560
err_msg = re .escape ("but are not explicitly set as requested or not requested" )
@@ -2531,14 +2564,22 @@ def test_passed_unrequested_metadata(func):
2531
2564
X = X ,
2532
2565
y = y ,
2533
2566
params = dict (metadata = []),
2567
+ ** extra_args ,
2534
2568
)
2535
2569
2536
2570
2537
2571
@pytest .mark .usefixtures ("enable_slep006" )
2538
2572
@pytest .mark .parametrize (
2539
- "func" , [cross_validate , cross_val_score , cross_val_predict , learning_curve ]
2573
+ "func, extra_args" ,
2574
+ [
2575
+ (cross_validate , {}),
2576
+ (cross_val_score , {}),
2577
+ (cross_val_predict , {}),
2578
+ (learning_curve , {}),
2579
+ (validation_curve , {"param_name" : "alpha" , "param_range" : np .array ([1 ])}),
2580
+ ],
2540
2581
)
2541
- def test_validation_functions_routing (func ):
2582
+ def test_validation_functions_routing (func , extra_args ):
2542
2583
"""Check that the respective cv method is properly dispatching the metadata
2543
2584
to the consumer."""
2544
2585
scorer_registry = _Registry ()
@@ -2563,12 +2604,11 @@ def test_validation_functions_routing(func):
2563
2604
fit_sample_weight = rng .rand (n_samples )
2564
2605
fit_metadata = rng .rand (n_samples )
2565
2606
2566
- extra_params = {
2607
+ scoring_args = {
2567
2608
cross_validate : dict (scoring = dict (my_scorer = scorer , accuracy = "accuracy" )),
2568
- # cross_val_score and learning_curve don't support multiple scorers:
2569
2609
cross_val_score : dict (scoring = scorer ),
2570
2610
learning_curve : dict (scoring = scorer ),
2571
- # cross_val_predict doesn't need a scorer
2611
+ validation_curve : dict ( scoring = scorer ),
2572
2612
cross_val_predict : dict (),
2573
2613
}
2574
2614
@@ -2590,7 +2630,8 @@ def test_validation_functions_routing(func):
2590
2630
X = X ,
2591
2631
y = y ,
2592
2632
cv = splitter ,
2593
- ** extra_params [func ],
2633
+ ** scoring_args [func ],
2634
+ ** extra_args ,
2594
2635
params = params ,
2595
2636
)
2596
2637
0 commit comments