8000 Merge pull request #5214 from glennq/mlp_refactoring · scikit-learn/scikit-learn@965a715 · GitHub
[go: up one dir, main page]

Skip to content

Commit 965a715

Browse files
committed
Merge pull request #5214 from glennq/mlp_refactoring
[MRG + 2] Mlp with adam, nesterov's momentum, early stopping
2 parents efb0179 + 917bacb commit 965a715

18 files changed

+3051
-23
lines changed

benchmarks/bench_mnist.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,21 @@
99
covertype dataset, the feature space is homogenous.
1010
1111
Example of output :
12-
1312
[..]
13+
1414
Classification performance:
1515
===========================
16-
Classifier train-time test-time error-rat
16+
Classifier train-time test-time error-rate
1717
------------------------------------------------------------
18-
Nystroem-SVM 105.07s 0.91s 0.0227
19-
ExtraTrees 48.20s 1.22s 0.0288
20-
RandomForest 47.17s 1.21s 0.0304
21-
SampledRBF-SVM 140.45s 0.84s 0.0486
22-
CART 22.84s 0.16s 0.1214
23-
dummy 0.01s 0.02s 0.8973
24-
18+
MLP_adam 53.46s 0.11s 0.0224
19+
Nystroem-SVM 112.97s 0.92s 0.0228
20+
MultilayerPerceptron 24.33s 0.14s 0.0287
21+
ExtraTrees 42.99s 0.57s 0.0294
22+
RandomForest 42.70s 0.49s 0.0318
23+
SampledRBF-SVM 135.81s 0.56s 0.0486
24+
LinearRegression-SAG 16.67s 0.06s 0.0824
25+
CART 20.69s 0.02s 0.1219
26+
dummy 0.00s 0.01s 0.8973
2527
"""
2628
from __future__ import division, print_function
2729

@@ -48,6 +50,7 @@
4850
from sklearn.tree import DecisionTreeClassifier
4951
from sklearn.utils import check_array
5052
from sklearn.linear_model import LogisticRegression
53+
from sklearn.neural_network import MLPClassifier
5154

5255
# Memoize the data extraction and memory map the resulting
5356
# train / test splits in readonly mode
@@ -84,11 +87,19 @@ def load_data(dtype=np.float32, order='F'):
8487
'CART': DecisionTreeClassifier(),
8588
'ExtraTrees': ExtraTreesClassifier(n_estimators=100),
8689
'RandomForest': RandomForestClassifier(n_estimators=100),
87-
'Nystroem-SVM':
88-
make_pipeline(Nystroem(gamma=0.015, n_components=1000), LinearSVC(C=100)),
89-
'SampledRBF-SVM':
90-
make_pipeline(RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=100)),
91-
'LinearRegression-SAG': LogisticRegression(solver='sag', tol=1e-1, C=1e4)
90+
'Nystroem-SVM': make_pipeline(
91+
Nystroem(gamma=0.015, n_components=1000), LinearSVC(C=100)),
92+
'SampledRBF-SVM': make_pipeline(
93+
RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=100)),
94+
'LinearRegression-SAG': LogisticRegression(solver='sag', tol=1e-1, C=1e4),
95+
'MultilayerPerceptron': MLPClassifier(
96+
hidden_layer_sizes=(100, 100), max_iter=400, alpha=1e-4,
97+
algorithm='sgd', learning_rate_init=0.2, momentum=0.9, verbose=1,
98+
tol=1e-4, random_state=1),
99+
'MLP-adam': MLPClassifier(
100+
hidden_layer_sizes=(100, 100), max_iter=400, alpha=1e-4,
101+
algorithm='adam', learning_rate_init=0.001, verbose=1,
102+
tol=1e-4, random_state=1)
92103
}
93104

94105

87.3 KB
Loading

doc/modules/classes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,8 @@ See the :ref:`metrics` section of the user guide for further details.
10481048
:template: class.rst
10491049

10501050
neural_network.BernoulliRBM
1051+
neural_network.MLPClassifier
1052+
neural_network.MLPRegressor
10511053

10521054

10531055
.. _calibration_ref:

0 commit comments

Comments
 (0)
0