8000 Use full dataset · scikit-learn/scikit-learn@9aec439 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9aec439

Browse files
committed
Use full dataset
1 parent 7b96ab1 commit 9aec439

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

examples/linear_model/plot_sparse_logistic_regression_mnist.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@
3232
from sklearn.preprocessing import StandardScaler
3333
from sklearn.utils import check_random_state
3434

35-
# Turn down for faster convergence
3635
t0 = time.time()
37-
train_samples = 5000
3836

3937
# Load data from https://www.openml.org/d/554
4038
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
@@ -46,14 +44,14 @@
4644
X = X.reshape((X.shape[0], -1))
4745

4846
X_train, X_test, y_train, y_test = train_test_split(
49-
X, y, train_size=train_samples, test_size=10000, random_state=random_state
47+
X, y, test_size=0.2, random_state=random_state
5048
)
49+
train_samples, _ = X_train.shape
5150

5251
scaler = StandardScaler()
5352
X_train = scaler.fit_transform(X_train)
5453
X_test = scaler.transform(X_test)
5554

56-
# Turn up tolerance for faster convergence
5755
clf = LogisticRegression(C=50.0 / train_samples, penalty="l1", solver="saga", tol=0.1, random_state=random_state)
5856
clf.fit(X_train, y_train)
5957
sparsity = np.mean(clf.coef_ == 0) * 100

0 commit comments

Comments
 (0)
0