FIX test_loss_boundary · scikit-learn/scikit-learn@dea5fd1 · GitHub

Commit dea5fd1

FIX test_loss_boundary
1 parent 38c110b commit dea5fd1

File tree

1 file changed: +24 -42 lines changed

sklearn/_loss/tests/test_loss.py

Lines changed: 24 additions & 42 deletions
@@ -71,16 +71,10 @@ def random_y_true_raw_prediction(
         high = min(high, y_bound[1])
         y_true = rng.uniform(low, high, size=n_samples)
         # set some values at special boundaries
-        if (
-            loss.interval_y_true.low == 0
-            and loss.interval_y_true.low_inclusive
-        ):
+        if loss.interval_y_true.low == 0 and loss.interval_y_true.low_inclusive:
             y_true[:: (n_samples // 3)] = 0
-        if (
-            loss.interval_y_true.high == 1
-            and loss.interval_y_true.high_inclusive
-        ):
-            y_true[1:: (n_samples // 3)] = 1
+        if loss.interval_y_true.high == 1 and loss.interval_y_true.high_inclusive:
+            y_true[1 :: (n_samples // 3)] = 1

     return y_true, raw_prediction

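Note (not part of the diff): the strided assignments above plant exact boundary values in y_true at disjoint positions, so the closed ends of the y_true interval get exercised. A minimal NumPy sketch of the same indexing:

    import numpy as np

    n_samples = 10
    y_true = np.full(n_samples, 0.5)
    # every (n_samples // 3)-th entry, starting at index 0, gets the lower bound
    y_true[:: (n_samples // 3)] = 0  # indices 0, 3, 6, 9
    # every (n_samples // 3)-th entry, starting at index 1, gets the upper bound
    y_true[1 :: (n_samples // 3)] = 1  # indices 1, 4, 7
    # y_true is now [0, 1, 0.5, 0, 1, 0.5, 0, 1, 0.5, 0]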
@@ -96,9 +90,7 @@ def numerical_derivative(func, x, eps):
     f_minus_1h = func(x - h)
     f_plus_1h = func(x + h)
     f_plus_2h = func(x + 2 * h)
-    return (-f_plus_2h + 8 * f_plus_1h - 8 * f_minus_1h + f_minus_2h) / (
-        12.0 * eps
-    )
+    return (-f_plus_2h + 8 * f_plus_1h - 8 * f_minus_1h + f_minus_2h) / (12.0 * eps)


 @pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
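Note (not part of the diff): the reformatted return line is the standard five-point central-difference stencil of order four, f'(x) ≈ (-f(x + 2h) + 8 f(x + h) - 8 f(x - h) + f(x - 2h)) / (12 h). A self-contained check against a known derivative; the step size h = eps is an assumption about the lines above this hunk:

    import numpy as np

    def numerical_derivative(func, x, eps):
        # same stencil as in the diff; h = eps elementwise is assumed here
        h = np.full_like(x, eps)
        f_minus_2h = func(x - 2 * h)
        f_minus_1h = func(x - h)
        f_plus_1h = func(x + h)
        f_plus_2h = func(x + 2 * h)
        return (-f_plus_2h + 8 * f_plus_1h - 8 * f_minus_1h + f_minus_2h) / (12.0 * eps)

    x = np.array([0.0, 0.5, 1.0])
    # d/dx exp(x) = exp(x); the truncation error is O(eps**4)
    print(np.max(np.abs(numerical_derivative(np.exp, x, eps=1e-3) - np.exp(x))))  # roughly 1e-13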
@@ -119,14 +111,15 @@ def test_loss_boundary(loss):

     assert loss.in_y_true_range(y_true)

+    n = y_true.shape[0]
     low, high = _inclusive_low_high(loss.interval_y_pred)
     if loss.is_multiclass:
-        y_pred = np.empty((10, 3))
-        y_pred[:, 0] = np.linspace(low, high, num=10)
+        y_pred = np.empty((n, 3))
+        y_pred[:, 0] = np.linspace(low, high, num=n)
         y_pred[:, 1] = 0.5 * (1 - y_pred[:, 0])
         y_pred[:, 2] = 0.5 * (1 - y_pred[:, 0])
     else:
-        y_pred = np.linspace(low, high, num=10)
+        y_pred = np.linspace(low, high, num=n)

     assert loss.in_y_pred_range(y_pred)

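This hunk carries the fix named in the commit message: y_pred is now generated with as many rows as y_true (n = y_true.shape[0]) instead of a hard-coded 10, presumably so the two arrays stay aligned wherever both are used later in the test. In the multiclass branch the three columns of each row sum to 1, i.e. a valid probability vector. A small sketch (plain NumPy, hypothetical bounds) of that construction:

    import numpy as np

    n = 5
    low, high = 0.0, 1.0  # hypothetical inclusive bounds for y_pred
    y_pred = np.empty((n, 3))
    y_pred[:, 0] = np.linspace(low, high, num=n)
    y_pred[:, 1] = 0.5 * (1 - y_pred[:, 0])
    y_pred[:, 2] = 0.5 * (1 - y_pred[:, 0])
    print(y_pred.sum(axis=1))  # every row sums to 1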
@@ -185,8 +178,7 @@ def test_loss_boundary_y_true(loss, y_true_success, y_true_fail):


 @pytest.mark.parametrize(
-    "loss, y_pred_success, y_pred_fail",
-    Y_COMMON_PARAMS + Y_PRED_PARAMS  # type: ignore
+    "loss, y_pred_success, y_pred_fail", Y_COMMON_PARAMS + Y_PRED_PARAMS  # type: ignore
 )
 def test_loss_boundary_y_pred(loss, y_pred_success, y_pred_fail):
     """Test boundaries of y_pred for loss functions."""
@@ -203,9 +195,7 @@ def test_loss_boundary_y_pred(loss, y_pred_success, y_pred_fail):
 @pytest.mark.parametrize("out1", [None, 1])
 @pytest.mark.parametrize("out2", [None, 1])
 @pytest.mark.parametrize("n_threads", [1, 2])
-def test_loss_dtype(
-    loss, dtype_in, dtype_out, sample_weight, out1, out2, n_threads
-):
+def test_loss_dtype(loss, dtype_in, dtype_out, sample_weight, out1, out2, n_threads):
     """Test acceptance of dtypes in loss functions.

     Check that loss accepts if all input arrays are either all float32 or all
@@ -450,14 +440,10 @@ def test_sample_weight_multiplies_gradients(loss, sample_weight):
     rng = np.random.RandomState(42)
     sample_weight = rng.normal(size=n_samples).astype(np.float64)

-    baseline_prediction = loss.fit_intercept_only(
-        y_true=y_true, sample_weight=None
-    )
+    baseline_prediction = loss.fit_intercept_only(y_true=y_true, sample_weight=None)

     if loss.n_classes <= 2:
-        raw_prediction = np.zeros(
-            shape=(n_samples,), dtype=baseline_prediction.dtype
-        )
+        raw_prediction = np.zeros(shape=(n_samples,), dtype=baseline_prediction.dtype)
     else:
         raw_prediction = np.zeros(
             shape=(n_samples, loss.n_classes), dtype=baseline_prediction.dtype
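Note (not from the commit): per its name, the surrounding test checks that per-sample gradients scale linearly in sample_weight. A pure-NumPy illustration of that property for the half squared error, whose per-sample gradient with respect to the raw prediction is raw_prediction - y_true:

    import numpy as np

    rng = np.random.RandomState(42)
    n_samples = 8
    y_true = rng.normal(size=n_samples)
    raw_prediction = np.zeros(n_samples)
    sample_weight = rng.normal(size=n_samples)

    grad_unweighted = raw_prediction - y_true  # gradient of 0.5 * (y_true - raw)**2
    grad_weighted = sample_weight * (raw_prediction - y_true)
    np.testing.assert_allclose(grad_weighted, sample_weight * grad_unweighted)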
@@ -555,15 +541,19 @@ def test_gradients_hessians_numerically(loss, sample_weight):

         def loss_func(x):
             return loss.loss(
-                y_true=y_true, raw_prediction=x, sample_weight=sample_weight,
+                y_true=y_true,
+                raw_prediction=x,
+                sample_weight=sample_weight,
             )

         g_numeric = numerical_derivative(loss_func, raw_prediction, eps=1e-6)
         assert_allclose(g, g_numeric, rtol=5e-6, atol=1e-10)

         def grad_func(x):
             return loss.gradient(
-                y_true=y_true, raw_prediction=x, sample_weight=sample_weight,
+                y_true=y_true,
+                raw_prediction=x,
+                sample_weight=sample_weight,
             )

         h_numeric = numerical_derivative(grad_func, raw_prediction, eps=1e-6)
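Note (not from the commit): the pattern here is to compare the analytic gradient against a numerical derivative of the loss, and the analytic hessian against a numerical derivative of the gradient. A self-contained sketch of that pattern for the half squared error (gradient raw - y, hessian 1); num_diff is a hypothetical stand-in for numerical_derivative above:

    import numpy as np

    def num_diff(func, x, eps=1e-6):
        # five-point central difference, same stencil as numerical_derivative
        return (-func(x + 2 * eps) + 8 * func(x + eps)
                - 8 * func(x - eps) + func(x - 2 * eps)) / (12.0 * eps)

    y = np.array([0.3, -1.2, 2.0])
    raw = np.array([0.0, 0.5, 1.5])
    grad = raw - y            # analytic gradient of 0.5 * (y - raw)**2
    hess = np.ones_like(raw)  # analytic hessian

    np.testing.assert_allclose(grad, num_diff(lambda x: 0.5 * (y - x) ** 2, raw), rtol=5e-6, atol=1e-10)
    np.testing.assert_allclose(hess, num_diff(lambda x: x - y, raw), rtol=5e-6, atol=1e-10)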
@@ -588,9 +578,7 @@ def loss_func(x):
                     sample_weight=sample_weight,
                 )

-            g_numeric = numerical_derivative(
-                loss_func, raw_prediction[:, k], eps=1e-5
-            )
+            g_numeric = numerical_derivative(loss_func, raw_prediction[:, k], eps=1e-5)
             assert_allclose(g[:, k], g_numeric, rtol=5e-6, atol=1e-10)

             def grad_func(x):
@@ -602,9 +590,7 @@ def grad_func(x):
                     sample_weight=sample_weight,
                 )[:, k]

-            h_numeric = numerical_derivative(
-                grad_func, raw_prediction[:, k], eps=1e-6
-            )
+            h_numeric = numerical_derivative(grad_func, raw_prediction[:, k], eps=1e-6)
             if loss.approx_hessian:
                 assert np.all(h >= h_numeric)
             else:
@@ -676,9 +662,7 @@ def fprime2(x: np.ndarray) -> np.ndarray:
     optimum = optimum.ravel()
     assert_allclose(loss.inverse(optimum), y_true)
     assert_allclose(func(optimum), 0, atol=1e-14)
-    assert_allclose(
-        loss.gradient(y_true=y_true, raw_prediction=optimum), 0, atol=5e-7
-    )
+    assert_allclose(loss.gradient(y_true=y_true, raw_prediction=optimum), 0, atol=5e-7)


 @pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
@@ -740,7 +724,7 @@ def fun(x):
         method="SLSQP",
         constraints={
             "type": "eq",
-            "fun": lambda x: np.ones((1, loss.n_classes)) @ x
+            "fun": lambda x: np.ones((1, loss.n_classes)) @ x,
         },
     )
     grad = loss.gradient(
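Note (not from the commit): with scipy.optimize.minimize, an "eq" constraint is satisfied where its "fun" returns 0, so the lambda above pins the class-wise raw predictions to sum to zero; softmax-based multiclass losses are unchanged when a constant is added to every class, so such a pin presumably makes the optimum unique. A minimal self-contained example of the same constraint form, with a toy objective:

    import numpy as np
    from scipy.optimize import minimize

    n_classes = 3
    fun = lambda x: np.sum((x - np.array([1.0, 2.0, 3.0])) ** 2)  # toy objective

    opt = minimize(
        fun,
        np.zeros(n_classes),
        method="SLSQP",
        constraints={
            "type": "eq",
            "fun": lambda x: np.ones((1, n_classes)) @ x,  # feasible iff sum(x) == 0
        },
    )
    print(opt.x, opt.x.sum())  # roughly [-1, 0, 1], sum close to 0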
@@ -784,9 +768,7 @@ def test_specific_fit_intercept_only(loss, func, random_dist):
     assert baseline_prediction == approx(loss.link(func(y_train)))
     assert loss.inverse(baseline_prediction) == approx(func(y_train))
     if isinstance(loss, IdentityLink):
-        assert_allclose(
-            loss.inverse(baseline_prediction), baseline_prediction
-        )
+        assert_allclose(loss.inverse(baseline_prediction), baseline_prediction)

     # Test baseline at boundary
     if loss.interval_y_true.low_inclusive:
@@ -835,5 +817,5 @@ def test_binary_and_categorical_crossentropy():
     raw_cce[:, 1] = 0.5 * raw_prediction
     assert_allclose(
         bce.loss(y_true=y_train, raw_prediction=raw_prediction),
-        cce.loss(y_true=y_train, raw_prediction=raw_cce)
+        cce.loss(y_true=y_train, raw_prediction=raw_cce),
     )
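Note (not from the commit): the line raw_cce[:, 1] = 0.5 * raw_prediction suggests a symmetric two-column encoding of the binary raw prediction r as (-r/2, +r/2). For that encoding, the two-class softmax probability of class 1 equals the sigmoid of r, which is why the binary and categorical cross-entropy values can agree here. A quick NumPy check of that identity:

    import numpy as np

    r = np.linspace(-4.0, 4.0, 9)
    sigmoid = 1.0 / (1.0 + np.exp(-r))
    # two-class softmax with raw predictions (-r/2, +r/2)
    softmax_1 = np.exp(0.5 * r) / (np.exp(-0.5 * r) + np.exp(0.5 * r))
    np.testing.assert_allclose(softmax_1, sigmoid)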

0 commit comments