From 972a2b2d182ca6322037f9eda4ec8c7e8797f019 Mon Sep 17 00:00:00 2001
From: Joan Massich
Date: Fri, 1 Mar 2019 11:10:39 +0100
Subject: [PATCH 1/5] TST: Add dtype match and stability

---
 sklearn/linear_model/tests/test_sgd.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py
index 405f7003798ff..9756956a31f89 100644
--- a/sklearn/linear_model/tests/test_sgd.py
+++ b/sklearn/linear_model/tests/test_sgd.py
@@ -7,6 +7,7 @@
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_almost_equal
+from sklearn.utils.testing import assert_allclose
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_less
@@ -1673,3 +1674,23 @@ def test_SGDClassifier_fit_for_all_backends(backend):
     with parallel_backend(backend=backend):
         clf_parallel.fit(X, y)
     assert_array_almost_equal(clf_sequential.coef_, clf_parallel.coef_)
+
+
+@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier,
+                                   SGDRegressor, SparseSGDRegressor])
+@pytest.mark.parametrize(
+    'loss', ['hinge', 'squared_hinge', 'log', 'modified_huber'])
+def test_dtype_sgd_match_and_stability(klass, loss):
+    # rtol = 1e-2 if os.name == 'nt' and _IS_32BIT else 1e-5
+    rtol = 1e-5
+    clf_dict = dict()
+    for current_dtype in (np.float32, np.float64):
+        clf_dict[current_dtype] = (klass(alpha=0.01)
+                                   .fit(X=X.astype(current_dtype, copy=False),
+                                        y=np.array(Y, dtype=current_dtype)))
+
+    assert clf_dict[np.float32].coef_.dtype == np.float32
+    assert clf_dict[np.float64].coef_.dtype == np.float64
+    assert_allclose(clf_dict[np.float32].coef_,
+                    clf_dict[np.float64].coef_,
+                    rtol=rtol)

From df81450e7d988a7b39e6b15f5075acf97b51e5fc Mon Sep 17 00:00:00 2001
From: Joan Massich
Date: Fri, 1 Mar 2019 11:14:20 +0100
Subject: [PATCH 2/5] update whatsnew

---
 doc/whats_new/v0.21.rst | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst
index f185491ded469..e19551891ce28 100644
--- a/doc/whats_new/v0.21.rst
+++ b/doc/whats_new/v0.21.rst
@@ -187,6 +187,10 @@ Support for Python 3.4 and below has been officially dropped.
 :mod:`sklearn.linear_model`
 ...........................
 
+- |Enhancement| :mod:`linear_model.stochastic_gradient` now preserves
+  ``float32`` and ``float64`` dtypes. :issue:`11000` by
+  :user:`Joan Massich `.
+
 - |Feature| :class:`linear_model.LogisticRegression` and
   :class:`linear_model.LogisticRegressionCV` now support Elastic-Net penalty,
   with the 'saga' solver. :issue:`11646` by :user:`Nicolas Hug `.
From a4af83b014d397546ff8b33c74c428cbfe3f45f2 Mon Sep 17 00:00:00 2001
From: Joan Massich
Date: Fri, 1 Mar 2019 12:00:52 +0100
Subject: [PATCH 3/5] change all np.float64 for 64, 32 list or X.dtype

---
 sklearn/linear_model/stochastic_gradient.py | 46 +++++++++++++++++++++++++++-------------------
 1 file changed, 27 insertions(+), 19 deletions(-)

diff --git a/sklearn/linear_model/stochastic_gradient.py b/sklearn/linear_model/stochastic_gradient.py
index d094cd0988853..4189ec1860d24 100644
--- a/sklearn/linear_model/stochastic_gradient.py
+++ b/sklearn/linear_model/stochastic_gradient.py
@@ -223,7 +223,7 @@ def _validate_sample_weight(self, sample_weight, n_samples):
         return sample_weight
 
     def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
-                                intercept_init=None):
+                                intercept_init=None, dtype=np.float64):
         """Allocate mem for parameters; initialize if provided."""
         if n_classes > 2:
             # allocate coef_ for multi-class
@@ -235,7 +235,7 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
                 self.coef_ = coef_init
             else:
                 self.coef_ = np.zeros((n_classes, n_features),
-                                      dtype=np.float64, order="C")
+                                      dtype=dtype, order="C")
 
             # allocate intercept_ for multi-class
             if intercept_init is not None:
@@ -245,12 +245,12 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
                                      "does not match dataset.")
                 self.intercept_ = intercept_init
             else:
-                self.intercept_ = np.zeros(n_classes, dtype=np.float64,
+                self.intercept_ = np.zeros(n_classes, dtype=dtype,
                                            order="C")
         else:
             # allocate coef_ for binary problem
             if coef_init is not None:
-                coef_init = np.asarray(coef_init, dtype=np.float64,
+                coef_init = np.asarray(coef_init, dtype=dtype,
                                        order="C")
                 coef_init = coef_init.ravel()
                 if coef_init.shape != (n_features,):
@@ -259,28 +259,28 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
                 self.coef_ = coef_init
             else:
                 self.coef_ = np.zeros(n_features,
-                                      dtype=np.float64,
+                                      dtype=dtype,
                                       order="C")
 
             # allocate intercept_ for binary problem
             if intercept_init is not None:
-                intercept_init = np.asarray(intercept_init, dtype=np.float64)
+                intercept_init = np.asarray(intercept_init, dtype=dtype)
                 if intercept_init.shape != (1,) and intercept_init.shape != ():
                     raise ValueError("Provided intercept_init "
                                      "does not match dataset.")
                 self.intercept_ = intercept_init.reshape(1,)
             else:
-                self.intercept_ = np.zeros(1, dtype=np.float64, order="C")
+                self.intercept_ = np.zeros(1, dtype=dtype, order="C")
 
         # initialize average parameters
         if self.average > 0:
             self.standard_coef_ = self.coef_
             self.standard_intercept_ = self.intercept_
             self.average_coef_ = np.zeros(self.coef_.shape,
-                                          dtype=np.float64,
+                                          dtype=dtype,
                                           order="C")
             self.average_intercept_ = np.zeros(self.standard_intercept_.shape,
-                                               dtype=np.float64,
+                                               dtype=dtype,
                                                order="C")
 
     def _make_validation_split(self, y):
@@ -331,12 +331,12 @@ def _make_validation_score_cb(self, validation_mask, X, y, sample_weight,
                                        sample_weight[validation_mask],
                                        classes=classes)
 
 
-def _prepare_fit_binary(est, y, i):
+def _prepare_fit_binary(est, y, i, dtype=np.float64):
     """Initialization for fit_binary.
 
     Returns y, coef, intercept, average_coef, average_intercept.
""" - y_i = np.ones(y.shape, dtype=np.float64, order="C") + y_i = np.ones(y.shape, dtype=dtype, order="C") y_i[y != est.classes_[i]] = -1.0 average_intercept = 0 average_coef = None @@ -412,7 +412,7 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter, # if average is not true, average_coef, and average_intercept will be # unused y_i, coef, intercept, average_coef, average_intercept = \ - _prepare_fit_binary(est, y, i) + _prepare_fit_binary(est, y, i, dtype=X.dtype) assert y_i.shape[0] == y.shape[0] == sample_weight.shape[0] dataset, intercept_decay = make_dataset(X, y_i, sample_weight) @@ -515,7 +515,10 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate, max_iter, classes, sample_weight, coef_init, intercept_init): - X, y = check_X_y(X, y, 'csr', dtype=np.float64, order="C", + X, y = check_X_y(X, y, + accept_sparse='csr', + dtype=[np.float64, np.float32], + order="C", accept_large_sparse=False) n_samples, n_features = X.shape @@ -531,7 +534,8 @@ def _partial_fit(self, X, y, alpha, C, if getattr(self, "coef_", None) is None or coef_init is not None: self._allocate_parameter_mem(n_classes, n_features, - coef_init, intercept_init) + coef_init, intercept_init, + dtype=X.dtype) elif n_features != self.coef_.shape[-1]: raise ValueError("Number of features %d does not match previous " "data %d." % (n_features, self.coef_.shape[-1])) @@ -564,7 +568,10 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None, if hasattr(self, "classes_"): self.classes_ = None - X, y = check_X_y(X, y, 'csr', dtype=np.float64, order="C", + X, y = check_X_y(X, y, + accept_sparse='csr', + dtype=[np.float64, np.float32], + order="C", accept_large_sparse=False) # labels can be encoded as float, int, or string literals @@ -1136,9 +1143,10 @@ def __init__(self, loss="squared_loss", penalty="l2", alpha=0.0001, def _partial_fit(self, X, y, alpha, C, loss, learning_rate, max_iter, sample_weight, coef_init, intercept_init): - X, y = check_X_y(X, y, "csr", copy=False, order='C', dtype=np.float64, + X, y = check_X_y(X, y, "csr", copy=False, order='C', + dtype=[np.float64, np.float32], accept_large_sparse=False) - y = y.astype(np.float64, copy=False) + y = y.astype(X.dtype, copy=False) # XXX: isn't this done in check_X_y already n_samples, n_features = X.shape @@ -1153,9 +1161,9 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate, "data %d." 
                              % (n_features, self.coef_.shape[-1]))
         if self.average > 0 and getattr(self, "average_coef_", None) is None:
             self.average_coef_ = np.zeros(n_features,
-                                          dtype=np.float64,
+                                          dtype=X.dtype,
                                           order="C")
-            self.average_intercept_ = np.zeros(1, dtype=np.float64, order="C")
+            self.average_intercept_ = np.zeros(1, dtype=X.dtype, order="C")
 
         self._fit_regressor(X, y, alpha, C, loss, learning_rate,
                             sample_weight, max_iter)

From faaed0b0f6fba10cf464b683c2a719cb845954a9 Mon Sep 17 00:00:00 2001
From: Joan Massich
Date: Fri, 1 Mar 2019 14:01:35 +0100
Subject: [PATCH 4/5] wip

---
 sklearn/linear_model/base.py      | 2 ++
 sklearn/linear_model/sgd_fast.pyx | 3 ++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/sklearn/linear_model/base.py b/sklearn/linear_model/base.py
index 54083fee1e904..7c4435da35625 100644
--- a/sklearn/linear_model/base.py
+++ b/sklearn/linear_model/base.py
@@ -87,6 +87,8 @@ def make_dataset(X, y, sample_weight, random_state=None):
         CSRData = CSRDataset64
         ArrayData = ArrayDataset64
 
+    sample_weight = sample_weight.astype(X.dtype, copy=False)  # XXX: I don't think this should be here
+
     if sp.issparse(X):
         dataset = CSRData(X.data, X.indptr, X.indices, y, sample_weight,
                           seed=seed)
diff --git a/sklearn/linear_model/sgd_fast.pyx b/sklearn/linear_model/sgd_fast.pyx
index ddea4b9710501..3917b34b4b61d 100644
--- a/sklearn/linear_model/sgd_fast.pyx
+++ b/sklearn/linear_model/sgd_fast.pyx
@@ -23,7 +23,8 @@ cdef extern from "sgd_fast_helpers.h":
     bint skl_isfinite(double) nogil
 
 from sklearn.utils.weight_vector cimport WeightVector
-from sklearn.utils.seq_dataset cimport SequentialDataset64 as SequentialDataset
+from sklearn.utils.seq_dataset cimport SequentialDataset32 as SequentialDataset
+# from sklearn.utils.seq_dataset cimport SequentialDataset32 as SequentialDataset
 
 np.import_array()

From 02c8f0ab77916128bca3b3c829b478d2bc118c34 Mon Sep 17 00:00:00 2001
From: Joan Massich
Date: Fri, 1 Mar 2019 16:02:57 +0100
Subject: [PATCH 5/5] start using fused types (it's not finished)

---
 sklearn/linear_model/sgd_fast.pyx | 45 +++++++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/sklearn/linear_model/sgd_fast.pyx b/sklearn/linear_model/sgd_fast.pyx
index 3917b34b4b61d..2999c968feb02 100644
--- a/sklearn/linear_model/sgd_fast.pyx
+++ b/sklearn/linear_model/sgd_fast.pyx
@@ -16,6 +16,7 @@ import sys
 from time import time
 
 cimport cython
+from cython cimport floating
 from libc.math cimport exp, log, sqrt, pow, fabs
 cimport numpy as np
 from numpy.math cimport INFINITY
@@ -23,8 +24,8 @@ cdef extern from "sgd_fast_helpers.h":
     bint skl_isfinite(double) nogil
 
 from sklearn.utils.weight_vector cimport WeightVector
+# from sklearn.utils.seq_dataset cimport SequentialDataset64 as SequentialDataset
 from sklearn.utils.seq_dataset cimport SequentialDataset32 as SequentialDataset
-# from sklearn.utils.seq_dataset cimport SequentialDataset32 as SequentialDataset
 
 np.import_array()
@@ -334,7 +335,7 @@ cdef class SquaredEpsilonInsensitive(Regression):
         return SquaredEpsilonInsensitive, (self.epsilon,)
 
 
-def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
+def plain_sgd(floating[::1] weights,
              double intercept,
             LossFunction loss,
             int penalty_type,
@@ -451,9 +452,9 @@ def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
     return standard_weights, standard_intercept, n_iter_
 
 
-def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
+def average_sgd(floating[::1] weights,
                 double intercept,
-                np.ndarray[double, ndim=1, mode='c'] average_weights,
+                floating[::1] average_weights,
                 double average_intercept,
                 LossFunction loss,
                 int penalty_type,
@@ -580,9 +581,9 @@ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
                        average)
 
 
-def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
+def _plain_sgd(floating[::1] weights,
               double intercept,
-              np.ndarray[double, ndim=1, mode='c'] average_weights,
+              floating[::1] average_weights,
               double average_intercept,
               LossFunction loss,
               int penalty_type,
@@ -606,8 +607,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
     cdef Py_ssize_t n_features = weights.shape[0]
 
     cdef WeightVector w = WeightVector(weights, average_weights)
-    cdef double* w_ptr = &weights[0]
-    cdef double *x_data_ptr = NULL
+    cdef floating* w_ptr = &weights[0]
+    cdef floating *x_data_ptr = NULL
     cdef int *x_ind_ptr = NULL
     cdef double* ps_ptr = NULL
@@ -622,8 +623,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
     cdef double score = 0.0
     cdef double best_loss = INFINITY
    cdef double best_score = -INFINITY
-    cdef double y = 0.0
-    cdef double sample_weight
+    cdef floating y = 0.0
+    cdef floating sample_weight
     cdef double class_weight = 1.0
     cdef unsigned int count = 0
     cdef unsigned int epoch = 0
@@ -639,11 +640,10 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
     cdef unsigned char [:] validation_mask_view = validation_mask
 
     # q vector is only used for L1 regularization
-    cdef np.ndarray[double, ndim = 1, mode = "c"] q = None
-    cdef double * q_data_ptr = NULL
-    if penalty_type == L1 or penalty_type == ELASTICNET:
-        q = np.zeros((n_features,), dtype=np.float64, order="c")
-        q_data_ptr = q.data
+    cdef floating[::1] q = np.zeros((n_features,), dtype=weights.dtype,
+                                    order="c")
+    cdef floating * q_data_ptr = &q[0]
+
     cdef double u = 0.0
 
     if penalty_type == L2:
@@ -758,7 +758,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
 
             # floating-point under-/overflow check.
             if (not skl_isfinite(intercept)
-                    or any_nonfinite(weights.data, n_features)):
+                    or any_nonfinite(&weights[0], n_features)):
                 infinity = True
                 break
@@ -803,7 +803,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
     return weights, intercept, average_weights, average_intercept, epoch + 1
 
 
-cdef bint any_nonfinite(double *w, int n) nogil:
+cdef bint any_nonfinite(floating *w, int n) nogil:
+    cdef int i
     for i in range(n):
         if not skl_isfinite(w[i]):
             return True
@@ -820,18 +821,18 @@ cdef double sqnorm(double * x_data_ptr, int * x_ind_ptr, int xnnz) nogil:
     return x_norm
 
 
-cdef void l1penalty(WeightVector w, double * q_data_ptr,
-                    int *x_ind_ptr, int xnnz, double u) nogil:
+cdef void l1penalty(WeightVector w, floating * q_data_ptr,
+                    int *x_ind_ptr, int xnnz, floating u) nogil:
     """Apply the L1 penalty to each updated feature
 
     This implements the truncated gradient approach by
    [Tsuruoka, Y., Tsujii, J., and Ananiadou, S., 2009].
    """
-    cdef double z = 0.0
+    cdef floating z = 0.0
     cdef int j = 0
     cdef int idx = 0
-    cdef double wscale = w.wscale
-    cdef double *w_data_ptr = w.w_data_ptr
+    cdef floating wscale = w.wscale
+    cdef floating *w_data_ptr = w.w_data_ptr
     for j in range(xnnz):
         idx = x_ind_ptr[j]
         z = w_data_ptr[idx]
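
Note (not part of the patch series): a minimal sketch of the behaviour that the whats_new entry and the new test_dtype_sgd_match_and_stability assert, assuming the five patches above are applied. The toy X/y values below are made up for illustration only.

    import numpy as np
    from sklearn.linear_model import SGDClassifier

    # float32 input data; before these patches it would be upcast to float64
    X = np.array([[-2, -1], [-1, -1], [1, 1], [2, 1]], dtype=np.float32)
    y = np.array([1, 1, 2, 2])

    clf = SGDClassifier(alpha=0.01, max_iter=20, tol=1e-3).fit(X, y)

    # With dtype preservation in place, coef_ is expected to stay float32,
    # and the float32 and float64 solutions should agree within the test's
    # rtol of 1e-5.
    print(clf.coef_.dtype)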