@@ -29,14 +29,11 @@ dtypes = [('64', 'double', 'np.float64'),
 import numpy as np
 from libc.math cimport exp, fabs, isfinite, log
 from libc.time cimport time, time_t
+from libc.stdio cimport printf
 
-from ._sgd_fast cimport LossFunction
-from ._sgd_fast cimport Log, SquaredLoss
-
+from .._loss._loss cimport CyLossFunction, CyHalfSquaredError, CyHalfBinomialLoss
 from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64
 
-from libc.stdio cimport printf
-
 
 {{for name_suffix, c_type, np_type in dtypes}}
 
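For context, not part of the patch itself: the `Log`/`SquaredLoss` classes from `_sgd_fast` are replaced by the shared Cython losses in `sklearn/_loss`. Below is a minimal NumPy sketch of the math these classes implement, assuming the usual half-binomial (y in {0, 1}) and half-squared-error conventions; the actual Cython signatures live in `_loss.pxd`, and the function names here are illustrative only.

    import numpy as np
    from scipy.special import expit

    # illustrative sketch, not sklearn's actual API
    def half_binomial_loss(y, raw):
        # log(1 + exp(raw)) - y * raw, written stably via logaddexp
        return np.logaddexp(0.0, raw) - y * raw

    def half_binomial_gradient(y, raw):
        # derivative of the half-binomial loss w.r.t. the raw prediction
        return expit(raw) - y

    def half_squared_error(y, raw):
        # 0.5 * (raw - y)^2; the "half" keeps the gradient free of a factor 2
        return 0.5 * (raw - y) ** 2

    def half_squared_gradient(y, raw):
        return raw - y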
@@ -77,7 +74,7 @@ cdef {{c_type}} _logsumexp{{name_suffix}}({{c_type}}* arr, int n_classes) noexcept nogil:
 {{for name_suffix, c_type, np_type in dtypes}}
 
 cdef class MultinomialLogLoss{{name_suffix}}:
-    cdef {{c_type}} _loss(self, {{c_type}} y, {{c_type}}* prediction, int n_classes,
+    cdef {{c_type}} cy_loss(self, {{c_type}} y, {{c_type}}* prediction, int n_classes,
                           {{c_type}} sample_weight) noexcept nogil:
         r"""Multinomial Logistic regression loss.
 
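The rename to `cy_loss` keeps the computation itself unchanged: the per-sample multinomial log loss via logsumexp, exactly as the context line in the next hunk shows. A rough NumPy equivalent (names here are illustrative, not from the patch):

    import numpy as np
    from scipy.special import logsumexp

    # illustrative sketch of the per-sample multinomial log loss
    def multinomial_log_loss(y, prediction, sample_weight):
        # loss = (logsumexp(prediction) - prediction[y]) * sample_weight
        return (logsumexp(prediction) - prediction[int(y)]) * sample_weight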
@@ -121,7 +118,7 @@ cdef class MultinomialLogLoss{{name_suffix}}:
         loss = (logsumexp_prediction - prediction[int(y)]) * sample_weight
         return loss
 
-    cdef void dloss(self, {{c_type}} y, {{c_type}}* prediction, int n_classes,
+    cdef void cy_gradient(self, {{c_type}} y, {{c_type}}* prediction, int n_classes,
                     {{c_type}} sample_weight, {{c_type}}* gradient_ptr) noexcept nogil:
         r"""Multinomial Logistic regression gradient of the loss.
 
@@ -331,7 +328,7 @@ def sag{{name_suffix}}(
     cdef bint prox = beta > 0 and saga
 
     # Loss function to optimize
-    cdef LossFunction loss
+    cdef CyLossFunction loss
     # Whether the loss function is multinomial
     cdef bint multinomial = False
     # Multinomial loss function
@@ -341,9 +338,9 @@ def sag{{name_suffix}}(
         multinomial = True
         multiloss = MultinomialLogLoss{{name_suffix}}()
     elif loss_function == "log":
-        loss = Log()
+        loss = CyHalfBinomialLoss()
     elif loss_function == "squared":
-        loss = SquaredLoss()
+        loss = CyHalfSquaredError()
     else:
         raise ValueError("Invalid loss parameter: got %s instead of "
                          "one of ('log', 'squared', 'multinomial')"
@@ -406,9 +403,9 @@ def sag{{name_suffix}}(
 
                 # compute the gradient for this sample, given the prediction
                 if multinomial:
-                    multiloss.dloss(y, &prediction[0], n_classes, sample_weight, &gradient[0])
+                    multiloss.cy_gradient(y, &prediction[0], n_classes, sample_weight, &gradient[0])
                 else:
-                    gradient[0] = loss.dloss(y, prediction[0]) * sample_weight
+                    gradient[0] = loss.cy_gradient(y, prediction[0]) * sample_weight
 
                 # L2 regularization by simply rescaling the weights
                 wscale *= wscale_update
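The rescaling mentioned in the context lines above is the usual lazy L2 trick: instead of shrinking every weight by the update factor each step, a single scalar `wscale` is shrunk and folded into the weights only when they are read or materialized. A minimal sketch of the idea (variable names assumed, not from the patch):

    import numpy as np

    # illustrative sketch of lazy L2 scaling
    def l2_by_rescaling(weights, wscale_update, n_steps):
        wscale = 1.0
        for _ in range(n_steps):
            wscale *= wscale_update  # O(1) per step instead of O(n_features)
        return weights * wscale      # fold the scale in once at the end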
@@ -539,7 +536,7 @@ def sag{{name_suffix}}(
                        (n_iter + 1, end_time - start_time))
                 break
             elif verbose:
-                printf('Epoch %d, change: %.8f \n', n_iter + 1,
+                printf('Epoch %d, change: %.8g \n', n_iter + 1,
                        max_change / max_weight)
             n_iter += 1
         # We do the error treatment here based on error code in status to avoid
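The `%.8f` to `%.8g` change matters for small convergence values: `%f` prints tiny magnitudes as a row of zeros, while `%g` falls back to scientific notation and keeps them readable. The same format specifiers in Python illustrate the difference:

    change = 3.2e-12
    print('%.8f' % change)  # 0.00000000 -- reads as converged, but is not zero
    print('%.8g' % change)  # 3.2e-12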
@@ -827,10 +824,10 @@ def _multinomial_grad_loss_all_samples(
            )
 
            # compute the gradient for this sample, given the prediction
-           multiloss.dloss(y, &prediction[0], n_classes, sample_weight, &gradient[0])
+           multiloss.cy_gradient(y, &prediction[0], n_classes, sample_weight, &gradient[0])
 
            # compute the loss for this sample, given the prediction
-           sum_loss += multiloss._loss(y, &prediction[0], n_classes, sample_weight)
+           sum_loss += multiloss.cy_loss(y, &prediction[0], n_classes, sample_weight)
 
            # update the sum of the gradient
            for j in range(xnnz):
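For orientation, `_multinomial_grad_loss_all_samples` accumulates the total loss and a weight-shaped gradient over all samples; the `xnnz` loop scatters each sample's class gradient through its sparse features. A dense NumPy sketch of the same aggregation (a simplification for illustration, not the actual routine):

    import numpy as np
    from scipy.special import logsumexp, softmax

    # illustrative dense stand-in for the sparse Cython loop
    def multinomial_grad_loss_all_samples(X, y, sample_weight, weights):
        predictions = X @ weights               # (n_samples, n_classes)
        sum_loss = 0.0
        sum_gradient = np.zeros_like(weights)   # (n_features, n_classes)
        for i in range(X.shape[0]):
            p = predictions[i]
            sum_loss += (logsumexp(p) - p[int(y[i])]) * sample_weight[i]
            g = softmax(p)
            g[int(y[i])] -= 1.0
            # dense equivalent of the sparse scatter over the xnnz features
            sum_gradient += np.outer(X[i], g * sample_weight[i])
        return sum_loss, sum_gradient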