From e612e68637080293db41f02f67fb63ae76f228a6 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Tue, 12 May 2020 17:13:45 +0200
Subject: [PATCH 1/2] MNT cleaner cdef loss function in sgd

---
 sklearn/linear_model/_sag_fast.pyx.tp  |  8 ++++----
 sklearn/linear_model/_sgd_fast.pxd     | 10 +++++-----
 sklearn/linear_model/_sgd_fast.pyx     | 31 ++++++++++++++++---------------
 sklearn/linear_model/tests/test_sgd.py |  6 +++---
 4 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp
index 141890497fcd2..8508340e3b329 100644
--- a/sklearn/linear_model/_sag_fast.pyx.tp
+++ b/sklearn/linear_model/_sag_fast.pyx.tp
@@ -152,7 +152,7 @@ cdef class MultinomialLogLoss{{name}}:
         loss = (logsumexp_prediction - prediction[int(y)]) * sample_weight
         return loss
 
-    cdef void _dloss(self, {{c_type}}* prediction, {{c_type}} y, int n_classes,
+    cdef void dloss(self, {{c_type}}* prediction, {{c_type}} y, int n_classes,
                      {{c_type}} sample_weight, {{c_type}}* gradient_ptr) nogil:
         r"""Multinomial Logistic regression gradient of the loss.
 
@@ -419,10 +419,10 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
 
             # compute the gradient for this sample, given the prediction
             if multinomial:
-                multiloss._dloss(prediction, y, n_classes, sample_weight,
+                multiloss.dloss(prediction, y, n_classes, sample_weight,
                                  gradient)
             else:
-                gradient[0] = loss._dloss(prediction[0], y) * sample_weight
+                gradient[0] = loss.dloss(prediction[0], y) * sample_weight
 
             # L2 regularization by simply rescaling the weights
             wscale *= wscale_update
@@ -783,7 +783,7 @@ def _multinomial_grad_loss_all_samples(
                                intercept, prediction, n_classes)
 
         # compute the gradient for this sample, given the prediction
-        multiloss._dloss(prediction, y, n_classes, sample_weight, gradient)
+        multiloss.dloss(prediction, y, n_classes, sample_weight, gradient)
 
         # compute the loss for this sample, given the prediction
         sum_loss += multiloss._loss(prediction, y, n_classes, sample_weight)
diff --git a/sklearn/linear_model/_sgd_fast.pxd b/sklearn/linear_model/_sgd_fast.pxd
index 53062097156b7..3c02f5ab1a834 100644
--- a/sklearn/linear_model/_sgd_fast.pxd
+++ b/sklearn/linear_model/_sgd_fast.pxd
@@ -3,24 +3,24 @@
 
 cdef class LossFunction:
     cdef double loss(self, double p, double y) nogil
-    cdef double _dloss(self, double p, double y) nogil
+    cdef double dloss(self, double p, double y) nogil
 
 
 cdef class Regression(LossFunction):
     cdef double loss(self, double p, double y) nogil
-    cdef double _dloss(self, double p, double y) nogil
+    cdef double dloss(self, double p, double y) nogil
 
 
 cdef class Classification(LossFunction):
     cdef double loss(self, double p, double y) nogil
-    cdef double _dloss(self, double p, double y) nogil
+    cdef double dloss(self, double p, double y) nogil
 
 
 cdef class Log(Classification):
     cdef double loss(self, double p, double y) nogil
-    cdef double _dloss(self, double p, double y) nogil
+    cdef double dloss(self, double p, double y) nogil
 
 
 cdef class SquaredLoss(Regression):
     cdef double loss(self, double p, double y) nogil
-    cdef double _dloss(self, double p, double y) nogil
+    cdef double dloss(self, double p, double y) nogil
diff --git a/sklearn/linear_model/_sgd_fast.pyx b/sklearn/linear_model/_sgd_fast.pyx
index cc34400dbcfef..d8f3a2b3056f0 100644
--- a/sklearn/linear_model/_sgd_fast.pyx
+++ b/sklearn/linear_model/_sgd_fast.pyx
@@ -66,7 +66,11 @@ cdef class LossFunction:
         """
         return 0.
 
-    def dloss(self, double p, double y):
+    def py_dloss(self, double p, double y):
+        """Python version of derivative for testing."""
+        return self.dloss(p, y)
+
+    cdef double dloss(self, double p, double y) nogil:
         """Evaluate the derivative of the loss function with respect to
         the prediction `p`.
 
@@ -81,9 +85,6 @@
         double
             The derivative of the loss function with regards to `p`.
         """
-        return self._dloss(p, y)
-
-    cdef double _dloss(self, double p, double y) nogil:
         # Implementation of dloss; separate function because cpdef and nogil
         # can't be combined.
         return 0.
@@ -95,7 +96,7 @@ cdef class Regression(LossFunction):
     cdef double loss(self, double p, double y) nogil:
         return 0.
 
-    cdef double _dloss(self, double p, double y) nogil:
+    cdef double dloss(self, double p, double y) nogil:
         return 0.
 
 
@@ -105,7 +106,7 @@ cdef class Classification(LossFunction):
     cdef double loss(self, double p, double y) nogil:
         return 0.
 
-    cdef double _dloss(self, double p, double y) nogil:
+    cdef double dloss(self, double p, double y) nogil:
         return 0.
 
 
@@ -126,7 +127,7 @@ cdef class ModifiedHuber(Classification):
         else:
             return -4.0 * z
 
-    cdef double _dloss(self, double p, double y) nogil:
+    cdef double dloss(self, double p, double y) nogil:
         cdef double z = p * y
         if z >= 1.0:
             return 0.0
@@ -161,7 +162,7 @@ cdef class Hinge(Classification):
             return self.threshold - z
         return 0.0
 
-    cdef double _dloss(self, double p, double y) nogil:
+    cdef double dloss(self, double p, double y) nogil:
         cdef double z = p * y
         if z <= self.threshold:
             return -y
@@ -193,7 +194,7 @@ cdef class SquaredHinge(Classification):
             return z * z
         return 0.0
 
-    cdef double _dloss(self, double p, double y) nogil:
+    cdef double dloss(self, double p, double y) nogil:
         cdef double z = self.threshold - p * y
         if z > 0:
             return -2 * y * z
@@ -215,7 +216,7 @@ cdef class Log(Classification):
             return -z
         return log(1.0 + exp(-z))
 
-    cdef double _dloss(self, double p, double y) nogil:
+    cdef double dloss(self, double p, double y) nogil:
         cdef double z = p * y
         # approximately equal and saves the computation of the log
         if z > 18.0:
@@ -233,7 +234,7 @@ cdef class SquaredLoss(Regression):
     cdef double loss(self, double p, double y) nogil:
         return 0.5 * (p - y) * (p - y)
 
-    cdef double _dloss(self, double p, double y) nogil:
+    cdef double dloss(self, double p, double y) nogil:
         return p - y
 
     def __reduce__(self):
@@ -262,7 +263,7 @@ cdef class Huber(Regression):
         else:
             return self.c * abs_r - (0.5 * self.c * self.c)
 
-    cdef double _dloss(self, double p, double y) nogil:
+    cdef double dloss(self, double p, double y) nogil:
         cdef double r = p - y
         cdef double abs_r = fabs(r)
         if abs_r <= self.c:
@@ -291,7 +292,7 @@ cdef class EpsilonInsensitive(Regression):
         cdef double ret = fabs(y - p) - self.epsilon
         return ret if ret > 0 else 0
 
-    cdef double _dloss(self, double p, double y) nogil:
+    cdef double dloss(self, double p, double y) nogil:
         if y - p > self.epsilon:
             return -1
         elif p - y > self.epsilon:
@@ -318,7 +319,7 @@ cdef class SquaredEpsilonInsensitive(Regression):
         cdef double ret = fabs(y - p) - self.epsilon
         return ret * ret if ret > 0 else 0
 
-    cdef double _dloss(self, double p, double y) nogil:
+    cdef double dloss(self, double p, double y) nogil:
         cdef double z
         z = y - p
         if z > self.epsilon:
@@ -542,7 +543,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
                     update = sqnorm(x_data_ptr, x_ind_ptr, xnnz)
                     update = loss.loss(p, y) / (update + 0.5 / C)
                 else:
-                    dloss = loss._dloss(p, y)
+                    dloss = loss.dloss(p, y)
                     # clip dloss with large values to avoid numerical
                     # instabilities
                     if dloss < -MAX_DLOSS:
diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py
index 22744a427b901..1e37a118caa93 100644
--- a/sklearn/linear_model/tests/test_sgd.py
+++ b/sklearn/linear_model/tests/test_sgd.py
@@ -1438,7 +1438,7 @@ def _test_gradient_common(loss_function, cases):
     # Test gradient of different loss functions
     # cases is a list of (p, y, expected)
     for p, y, expected in cases:
-        assert_almost_equal(loss_function.dloss(p, y), expected)
+        assert_almost_equal(loss_function.py_dloss(p, y), expected)
 
 
 def test_gradient_hinge():
@@ -1488,8 +1488,8 @@ def test_gradient_log():
         (17.9, -1.0, 1.0), (-17.9, 1.0, -1.0),
     ]
     _test_gradient_common(loss, cases)
-    assert_almost_equal(loss.dloss(18.1, 1.0), np.exp(-18.1) * -1.0, 16)
-    assert_almost_equal(loss.dloss(-18.1, -1.0), np.exp(-18.1) * 1.0, 16)
+    assert_almost_equal(loss.py_dloss(18.1, 1.0), np.exp(-18.1) * -1.0, 16)
+    assert_almost_equal(loss.py_dloss(-18.1, -1.0), np.exp(-18.1) * 1.0, 16)
 
 
 def test_gradient_squared_loss():

From 6e99dbdc6ff2b18ef4ff5a3b05a6c7de09a7a3c7 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Tue, 12 May 2020 17:22:00 +0200
Subject: [PATCH 2/2] MNT clarify comments for cdef and nogil

---
 sklearn/linear_model/_sgd_fast.pyx | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/sklearn/linear_model/_sgd_fast.pyx b/sklearn/linear_model/_sgd_fast.pyx
index d8f3a2b3056f0..ab1a274d37c8f 100644
--- a/sklearn/linear_model/_sgd_fast.pyx
+++ b/sklearn/linear_model/_sgd_fast.pyx
@@ -67,7 +67,10 @@ cdef class LossFunction:
         return 0.
 
     def py_dloss(self, double p, double y):
-        """Python version of derivative for testing."""
+        """Python version of `dloss` for testing.
+
+        Pytest needs a Python function and can't use cdef functions.
+        """
         return self.dloss(p, y)
 
     cdef double dloss(self, double p, double y) nogil:
@@ -85,8 +88,6 @@
         double
             The derivative of the loss function with regards to `p`.
         """
-        # Implementation of dloss; separate function because cpdef and nogil
-        # can't be combined.
         return 0.
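
Note on the pattern the two patches converge on: `dloss` stays a plain
`cdef ... nogil` method so the hot SGD loop can call it without the GIL,
while the small `py_dloss` shim is what the tests call, since Cython does
not allow `cpdef` and `nogil` to be combined (the point of the removed
comment). A minimal sketch of that shape in a standalone .pyx (the `Demo`
class and its squared-loss derivative are illustrative only, not part of
the patch):

    cdef class Demo:
        # hot-path method: plain cdef, callable without holding the GIL
        cdef double dloss(self, double p, double y) nogil:
            return p - y  # d/dp of the squared loss 0.5 * (p - y)**2

        # Python-visible shim so pytest can exercise dloss
        def py_dloss(self, double p, double y):
            return self.dloss(p, y)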
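After the rename, the tests reach the derivative only through the shim.
For example (a sketch assuming the patched extension is built; the
expected values follow the Hinge.dloss shown in the diff above):

    from sklearn.linear_model._sgd_fast import Hinge

    loss = Hinge(1.0)  # threshold = 1.0
    # subgradient of max(0, threshold - p*y) w.r.t. p: -y while p*y <= threshold
    assert loss.py_dloss(0.5, 1.0) == -1.0  # p*y = 0.5 <= 1.0
    assert loss.py_dloss(2.0, 1.0) == 0.0   # p*y = 2.0 > 1.0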