8000 [MRG] Diabetes example with GridSearchCV (#8268) · maskani-moh/scikit-learn@0598a53 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0598a53

Browse files
rishikksh20maskani-moh
authored andcommitted
[MRG] Diabetes example with GridSearchCV (scikit-learn#8268)
1 parent 4e3aefe commit 0598a53

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

examples/exercises/plot_cv_diabetes.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from sklearn.linear_model import LassoCV
2020
from sklearn.linear_model import Lasso
2121
from sklearn.model_selection import KFold
22-
from sklearn.model_selection import cross_val_score
22+
from sklearn.model_selection import GridSearchCV
2323

2424
diabetes = datasets.load_diabetes()
2525
X = diabetes.data[:150]
@@ -28,19 +28,13 @@
2828
lasso = Lasso(random_state=0)
2929
alphas = np.logspace(-4, -0.5, 30)
3030

31-
scores = list()
32-
scores_std = list()
33-
31+
tuned_parameters = [{'alpha': alphas}]
3432
n_folds = 3
3533

36-
for alpha in alphas:
37-
lasso.alpha = alpha
38-
this_scores = cross_val_score(lasso, X, y, cv=n_folds, n_jobs=1)
39-
scores.append(np.mean(this_scores))
40-
scores_std.append(np.std(this_scores))
41-
42-
scores, scores_std = np.array(scores), np.array(scores_std)
43-
34+
clf = GridSearchCV(lasso, tuned_parameters, cv=n_folds, refit=False)
35+
clf.fit(X, y)
36+
scores = clf.cv_results_['mean_test_score']
37+
scores_std = clf.cv_results_['std_test_score']
4438
plt.figure().set_size_inches(8, 6)
4539
plt.semilogx(alphas, scores)
4640

0 commit comments

Comments
 (0)
0