From de0b93a87eb4c12a59244ea48859a25dcd9b823f Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 27 Nov 2024 17:39:22 +0100 Subject: [PATCH] MNT add __reduce__ to loss objects --- sklearn/_loss/_loss.pyx.tp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/_loss/_loss.pyx.tp b/sklearn/_loss/_loss.pyx.tp index 56d3aebb6c6f1..6054d4c9472ca 100644 --- a/sklearn/_loss/_loss.pyx.tp +++ b/sklearn/_loss/_loss.pyx.tp @@ -818,6 +818,9 @@ cdef inline double_pair cgrad_hess_exponential( cdef class CyLossFunction: """Base class for convex loss functions.""" + def __reduce__(self): + return (self.__class__, ()) + cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil: """Compute the loss for a single sample. @@ -1013,6 +1016,11 @@ cdef class {{name}}(CyLossFunction): self.{{param}} = {{param}} {{endif}} + {{if param is not None}} + def __reduce__(self): + return (self.__class__, (self.{{param}},)) + {{endif}} + cdef inline double cy_loss(self, double y_true, double raw_prediction) noexcept nogil: return {{closs}}(y_true, raw_prediction{{with_param}})