8000 MNT add __reduce__ to loss objects (#30356) · scikit-learn/scikit-learn@26384be · GitHub
[go: up one dir, main page]

Skip to content

Commit 26384be

Browse files
adrinjalalijeremiedbb
authored andcommitted
MNT add __reduce__ to loss objects (#30356)
1 parent 5f7d66c commit 26384be

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

sklearn/_loss/_loss.pyx.tp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,9 @@ cdef inline double_pair cgrad_hess_exponential(
818818
cdef class CyLossFunction:
819819
"""Base class for convex loss functions."""
820820

821+
def __reduce__(self):
822+
return (self.__class__, ())
823+
821824
cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil:
822825
"""Compute the loss for a single sample.
823826

@@ -1013,6 +1016,11 @@ cdef class {{name}}(CyLossFunction):
10131016
self.{{param}} = {{param}}
10141017
{{endif}}
10151018

1019+
{{if param is not None}}
1020+
def __reduce__(self):
1021+
return (self.__class__, (self.{{param}},))
1022+
{{endif}}
1023+
10161024
cdef inline double cy_loss(self, double y_true, double raw_prediction) noexcept nogil:
10171025
return {{closs}}(y_true, raw_prediction{{with_param}})
10181026

0 commit comments

Comments
 (0)
0