ENH Improves memory usage for standard scalar (#20652) · rth/scikit-learn@f812e2a
Commit f812e2a

ENH Improves memory usage for standard scalar (scikit-learn#20652)
1 parent efede4b commit f812e2a

File tree

2 files changed (+26, -15 lines)


doc/whats_new/v1.0.rst (3 additions, 0 deletions)

@@ -684,6 +684,9 @@ Changelog
   `n_features_in_` and will be removed in 1.2. :pr:`20240` by
   :user:`Jérémie du Boisberranger <jeremiedbb>`.
 
+- |Efficiency| `preprocessing.StandardScaler` is faster and more memory
+  efficient. :pr:`20652` by `Thomas Fan`_.
+
 :mod:`sklearn.tree`
 ...................
 
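The entry above refers to the ordinary fit/transform path, which after this commit allocates fewer temporary arrays while computing the running mean and variance. A minimal usage sketch (the data is illustrative, not from the PR); NaN entries are ignored when the scaler computes its statistics:

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    # Illustrative data: the second feature contains a NaN, which the
    # scaler skips when computing the per-feature mean and variance.
    X = np.array([[1.0, 2.0],
                  [3.0, np.nan],
                  [5.0, 6.0]])

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    print(scaler.mean_)  # per-feature means, NaNs excluded
    print(scaler.var_)   # per-feature variances, NaNs excluded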

sklearn/utils/extmath.py (23 additions, 15 deletions)

@@ -955,24 +955,30 @@ def _incremental_mean_and_var(
     # new = the current increment
     # updated = the aggregated stats
     last_sum = last_mean * last_sample_count
+    X_nan_mask = np.isnan(X)
+    if np.any(X_nan_mask):
+        sum_op = np.nansum
+    else:
+        sum_op = np.sum
     if sample_weight is not None:
         if np_version >= parse_version("1.16.6"):
             # equivalent to np.nansum(X * sample_weight, axis=0)
             # safer because np.float64(X*W) != np.float64(X)*np.float64(W)
             # dtype arg of np.matmul only exists since version 1.16
             new_sum = _safe_accumulator_op(
-                np.matmul, sample_weight, np.where(np.isnan(X), 0, X)
+                np.matmul, sample_weight, np.where(X_nan_mask, 0, X)
             )
         else:
             new_sum = _safe_accumulator_op(
                 np.nansum, X * sample_weight[:, None], axis=0
             )
         new_sample_count = _safe_accumulator_op(
-            np.sum, sample_weight[:, None] * (~np.isnan(X)), axis=0
+            np.sum, sample_weight[:, None] * (~X_nan_mask), axis=0
         )
     else:
-        new_sum = _safe_accumulator_op(np.nansum, X, axis=0)
-        new_sample_count = np.sum(~np.isnan(X), axis=0)
+        new_sum = _safe_accumulator_op(sum_op, X, axis=0)
+        n_samples = X.shape[0]
+        new_sample_count = n_samples - np.sum(X_nan_mask, axis=0)
 
     updated_sample_count = last_sample_count + new_sample_count
 
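The additions in this hunk buy two things: np.isnan(X) is evaluated once and reused as X_nan_mask, and when the batch contains no NaNs the code falls back to np.sum, avoiding np.nansum's internal copy of the array (nansum zeroes out NaNs in a copy before summing). A standalone sketch of the same pattern (simplified; sklearn's actual helper additionally routes through _safe_accumulator_op for float32 accumulation safety):

    import numpy as np

    def column_sum_and_count(X):
        # Compute the NaN mask once and reuse it everywhere.
        X_nan_mask = np.isnan(X)
        # np.nansum copies the array to zero out NaNs; only pay that
        # cost when NaNs are actually present.
        sum_op = np.nansum if np.any(X_nan_mask) else np.sum
        new_sum = sum_op(X, axis=0)
        # Deriving the count from the mask avoids a second isnan pass.
        new_sample_count = X.shape[0] - np.sum(X_nan_mask, axis=0)
        return new_sum, new_sample_count

    X = np.array([[1.0, np.nan],
                  [2.0, 3.0]])
    print(column_sum_and_count(X))  # (array([3., 3.]), array([2, 1]))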

@@ -982,29 +988,31 @@ def _incremental_mean_and_var(
         updated_variance = None
     else:
         T = new_sum / new_sample_count
+        temp = X - T
         if sample_weight is not None:
             if np_version >= parse_version("1.16.6"):
                 # equivalent to np.nansum((X-T)**2 * sample_weight, axis=0)
                 # safer because np.float64(X*W) != np.float64(X)*np.float64(W)
                 # dtype arg of np.matmul only exists since version 1.16
-                new_unnormalized_variance = _safe_accumulator_op(
-                    np.matmul, sample_weight, np.where(np.isnan(X), 0, (X - T) ** 2)
-                )
                 correction = _safe_accumulator_op(
-                    np.matmul, sample_weight, np.where(np.isnan(X), 0, X - T)
+                    np.matmul, sample_weight, np.where(X_nan_mask, 0, temp)
                 )
-            else:
+                temp **= 2
                 new_unnormalized_variance = _safe_accumulator_op(
-                    np.nansum, (X - T) ** 2 * sample_weight[:, None], axis=0
+                    np.matmul, sample_weight, np.where(X_nan_mask, 0, temp)
                 )
+            else:
                 correction = _safe_accumulator_op(
-                    np.nansum, (X - T) * sample_weight[:, None], axis=0
+                    sum_op, temp * sample_weight[:, None], axis=0
+                )
+                temp *= temp
+                new_unnormalized_variance = _safe_accumulator_op(
+                    sum_op, temp * sample_weight[:, None], axis=0
                 )
         else:
-            new_unnormalized_variance = _safe_accumulator_op(
-                np.nansum, (X - T) ** 2, axis=0
-            )
-            correction = _safe_accumulator_op(np.nansum, X - T, axis=0)
+            correction = _safe_accumulator_op(sum_op, temp, axis=0)
+            temp **= 2
+            new_unnormalized_variance = _safe_accumulator_op(sum_op, temp, axis=0)
 
         # correction term of the corrected 2 pass algorithm.
         # See "Algorithms for computing the sample variance: analysis
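This second hunk applies the same idea to the variance update: X - T is materialized once as temp, the correction term is summed from it, and the buffer is then squared in place (temp **= 2, or temp *= temp on the weighted path), so no second X-sized temporary is allocated. Computing correction before new_unnormalized_variance is what makes the in-place reuse possible. A standalone sketch of the buffer reuse (shapes illustrative, not from the PR):

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.random((1000, 50))
    T = X.mean(axis=0)  # plays the role of new_sum / new_sample_count

    # Old pattern: X - T is materialized twice, plus a third buffer
    # for its square.
    correction_old = np.sum(X - T, axis=0)
    variance_old = np.sum((X - T) ** 2, axis=0)

    # New pattern: one temporary, squared in place after the first sum.
    temp = X - T
    correction_new = np.sum(temp, axis=0)
    temp **= 2  # in-place square; reuses the existing buffer
    variance_new = np.sum(temp, axis=0)

    assert np.allclose(correction_old, correction_new)
    assert np.allclose(variance_old, variance_new)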

0 commit comments