8000 [MRG] FIX: Allow coef_=None with warm_start=True in MultiTaskElasticNet by larsoner · Pull Request #12844 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG] FIX: Allow coef_=None with warm_start=True in MultiTaskElasticNet #12844

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
8000 Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sklearn/linear_model/coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,7 +1797,7 @@ def fit(self, X, y):
X, y, X_offset, y_offset, X_scale = _preprocess_data(
X, y, self.fit_intercept, self.normalize, copy=False)

if not self.warm_start or not hasattr(self, "coef_"):
if not self.warm_start or getattr(self, "coef_", None) is None:
self.coef_ = np.zeros((n_tasks, n_features), dtype=X.dtype.type,
order='F')

Expand Down
3 changes: 2 additions & 1 deletion sklearn/linear_model/tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ def test_multi_task_lasso_and_enet():
assert 0 < clf.dual_gap_ < 1e-5
assert_array_almost_equal(clf.coef_[0], clf.coef_[1])

clf = MultiTaskElasticNet(alpha=1.0, tol=1e-8, max_iter=1)
clf = MultiTaskElasticNet(alpha=1.0, tol=1e-8, max_iter=1, warm_start=True)
clf.coef_ = None # ensure that this is still supported with warm_start
assert_warns_message(ConvergenceWarning, 'did not converge', clf.fit, X, Y)


Expand Down
0