8000 MISC reinserted class_weight as fit parameter, added deprecation warn… · scikit-learn/scikit-learn@24ea7e7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 24ea7e7

Browse files
committed
MISC reinserted class_weight as fit parameter, added deprecation warning.
1 parent 36fa40a commit 24ea7e7

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

sklearn/svm/base.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __init__(self, impl, kernel, degree, gamma, coef0,
105105
self.sparse = sparse
106106
self.class_weight = class_weight
107107

108-
def fit(self, X, y, sample_weight=None):
108+
def fit(self, X, y, class_weight=None, sample_weight=None):
109109
"""Fit the SVM model according to the given training data.
110110
111111
Parameters
@@ -135,6 +135,11 @@ def fit(self, X, y, sample_weight=None):
135135
matrices as input.
136136
"""
137137
self._sparse = sp.isspmatrix(X) if self.sparse == "auto" else self.sparse
138+
if class_weight != None:
139+
warnings.warn("'class_weight' is now an initialization parameter."
140+
"Using it in the 'fit' method is deprecated.",
141+
DeprecationWarning)
142+
self.class_weight = class_weight
138143
fit = self._sparse_fit if self._sparse else self._dense_fit
139144
fit(X, y, sample_weight)
140145
return self
@@ -595,7 +600,7 @@ def _get_solver_type(self):
595600
+ error_string)
596601
return self._solver_type_dict[solver_type]
597602

598-
def fit(self, X, y):
603+
def fit(self, X, y, class_weight=None):
599604
"""Fit the model according to the given training data.
600605
601606
Parameters
@@ -617,6 +622,12 @@ def fit(self, X, y):
617622
Returns self.
618623
"""
619624

625+
if class_weight != None:
626+
warnings.warn("'class_weight' is now an initialization parameter."
627+
"Using it in the 'fit' method is deprecated.",
628+
DeprecationWarning)
629+
self.class_weight = class_weight
630+
620631
X = atleast2d_or_csr(X, dtype=np.float64, order="C")
621632
y = np.asarray(y, dtype=np.int32).ravel()
622633
self._sparse = sp.isspmatrix(X)

0 commit comments

Comments
 (0)
0