 from sklearn.utils._testing import assert_warns_message
 from sklearn.utils._testing import skip_if_32bit
 
-from sklearn.utils.extmath import density
+from sklearn.utils.extmath import density, _safe_accumulator_op
 from sklearn.utils.extmath import randomized_svd
 from sklearn.utils.extmath import row_norms
 from sklearn.utils.extmath import weighted_mode
@@ -470,32 +470,70 @@ def test_incremental_weighted_mean_and_variance_simple():
 
 
 def test_incremental_weighted_mean_and_variance():
-    rng = np.random.RandomState(42)
-    mult = 10
-    X = rng.rand(1000, 20)*mult
-    sample_weight = rng.rand(X.shape[0]) * mult
 
-    n = X.shape[0]
-    last_mean, last_weight_sum, last_var = 0, 0, 0
-    mean_exp = np.average(X, weights=sample_weight, axis=0)
-    var_exp = np.average(X ** 2, weights=sample_weight, axis=0) - mean_exp ** 2
-    for chunk_size in [1, 2, 50, n, n + 42]:
-        for batch in gen_batches(n, chunk_size):
-            last_mean, last_var, last_weight_sum = \
-                _incremental_weighted_mean_and_var(X[batch],
-                                                   sample_weight[batch],
-                                                   last_mean,
-                                                   last_var,
-                                                   last_weight_sum)
-        assert_almost_equal(last_mean, mean_exp)
-        assert_almost_equal(last_var, var_exp)
+    # Test correctness and numerical stability
+    def test(X, sample_weight, mean_exp=None, var_exp=None):
+        n = X.shape[0]
+        if mean_exp is None:
+            mean_exp = \
+                _safe_accumulator_op(
+                    np.average, X, weights=sample_weight, axis=0)
+        if var_exp is None:
+            var_exp = \
+                _safe_accumulator_op(
+                    np.average, (X-mean_exp)**2, weights=sample_weight, axis=0)
+        for chunk_size in [1, n//10 + 1, n//4 + 1, n//2 + 1, n]:
+            last_mean, last_weight_sum, last_var = 0, 0, 0
+            for batch in gen_batches(n, chunk_size):
+                last_mean, last_var, last_weight_sum = \
+                    _incremental_weighted_mean_and_var(X[batch],
+                                                       sample_weight[batch],
+                                                       last_mean,
+                                                       last_var,
+                                                       last_weight_sum)
+            assert_almost_equal(last_mean, mean_exp)
+            assert_almost_equal(last_var, var_exp, 6)
+
+    HIGH_MEAN = 10e6
+    LOW_VAR = 10e-7
+    NORMAL_MEAN = 0.0
+    NORMAL_VAR = 1.0
+    SIZE = (100, 20)
+
+    rng = np.random.RandomState(42)
+    NORMAL_WEIGHT = \
+        rng.normal(loc=NORMAL_MEAN, scale=NORMAL_VAR, size=(SIZE[0],))
+    ALMOST_ZERO_WEIGHT = \
+        rng.normal(loc=NORMAL_MEAN, scale=LOW_VAR, size=(SIZE[0],))
+    ALMOST_ONES_WEIGHT = rng.normal(loc=1.0, scale=LOW_VAR, size=(SIZE[0],))
+    JUST_WEIGHT = rng.normal(loc=10.0, scale=NORMAL_VAR, size=(SIZE[0],))
+    ONES_WEIGHT = np.ones(SIZE[0])
+
+    means = [NORMAL_MEAN, HIGH_MEAN]
+    vars = [NORMAL_VAR, LOW_VAR]
+    weights = \
+        [NORMAL_WEIGHT, ALMOST_ONES_WEIGHT, JUST_WEIGHT, ALMOST_ZERO_WEIGHT]
+    # Use a list, not a generator: both comparison loops iterate over it
+    means_vars = [(m, v) for m in means for v in vars]
+
+    # Compare with weighted np.average
+    for mean, var in means_vars:
+        X = rng.normal(loc=mean, scale=var, size=SIZE)
+        for weight in weights:
+            test(X, weight)
+
+    # Compare with unweighted np.average
+    for mean, var in means_vars:
+        X = rng.normal(loc=mean, scale=var, size=SIZE)
+        mean_exp = _safe_accumulator_op(np.mean, X, axis=0)
+        var_exp = _safe_accumulator_op(np.var, X, axis=0)
+        test(X, ONES_WEIGHT, mean_exp, var_exp)
 
 
 def test_incremental_weighted_mean_and_variance_ignore_nan():
     old_means = np.array([535., 535., 535., 535.])
     old_variances = np.array([4225., 4225., 4225., 4225.])
     old_weight_sum = np.array([2, 2, 2, 2], dtype=np.int32)
-
     sample_weights_X = np.ones(3)
     sample_weights_X_nan = np.ones(4)
 
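
As a point of reference, the chunk update exercised above can be sketched as follows. This is a minimal illustration assuming the standard pooled weighted-moments recurrence; the name weighted_mean_var_update is illustrative only and not the sklearn implementation, and the explicit float64 cast merely mimics the upcasting that _safe_accumulator_op provides for the reference values in the test.

import numpy as np


def weighted_mean_var_update(X, sample_weight, last_mean, last_var,
                             last_weight_sum):
    """Fold one chunk into running weighted mean/variance (pooled moments)."""
    # Accumulate in float64, in the spirit of _safe_accumulator_op.
    X = np.asarray(X, dtype=np.float64)
    sample_weight = np.asarray(sample_weight, dtype=np.float64)

    new_weight_sum = sample_weight.sum()
    new_mean = np.average(X, weights=sample_weight, axis=0)
    new_var = np.average((X - new_mean) ** 2, weights=sample_weight, axis=0)

    total_weight_sum = last_weight_sum + new_weight_sum
    total_mean = (last_weight_sum * last_mean
                  + new_weight_sum * new_mean) / total_weight_sum
    # Each block contributes its own variance plus the squared offset of its
    # mean from the combined mean, weighted by its share of the total weight.
    total_var = (last_weight_sum * (last_var + (last_mean - total_mean) ** 2)
                 + new_weight_sum * (new_var + (new_mean - total_mean) ** 2)
                 ) / total_weight_sum
    return total_mean, total_var, total_weight_sum


rng = np.random.RandomState(0)
X = rng.normal(loc=1e6, scale=1e-2, size=(100, 3))  # high mean, low spread
w = rng.uniform(0.5, 2.0, size=100)

mean, var, weight_sum = 0.0, 0.0, 0.0
for start in range(0, X.shape[0], 7):               # feed the data in chunks
    batch = slice(start, start + 7)
    mean, var, weight_sum = weighted_mean_var_update(
        X[batch], w[batch], mean, var, weight_sum)

np.testing.assert_allclose(mean, np.average(X, weights=w, axis=0))
np.testing.assert_allclose(
    var, np.average((X - mean) ** 2, weights=w, axis=0), rtol=1e-5)

Feeding the chunks through such an update reproduces the weighted np.average mean and variance up to rounding, which is what the asserts in the new test check across several chunk sizes, weight profiles, and mean/variance scales.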