WIP multilayer perceptron classifier by larsmans · Pull Request #1395 · scikit-learn/scikit-learn

Closed
wants to merge 12 commits into from

Conversation

larsmans
Member

This is a very early PR, just to let you know I'm working on this (and that I beat @amueller to it ;) and to keep a TODO list for myself.

This PR contains a working implementation of a multilayer perceptron classifier trained with batch gradient descent. Features:

  • one hidden layer with logistic activation function
  • minibatch SGD with momentum method (rolling average of gradients)
  • multiclass cross-entropy loss
  • [preliminary] 0.878 F1-score on the small version of 20news (n_hidden=400, alpha=0, max_iter=100, random_state=42, tol=1e-4, learning_rate=.02); a usage sketch follows the lists below
  • L2 penalty; doesn't seem to work very well, though, so turned off by default
  • tanh hidden layer activation

TODO:

  • reuse parts of SGDClassifier
  • hinge loss
  • multitarget/multilabel output
  • optimize further
  • adaptive learning rates?
  • tests
  • demos
  • documentation

Not TODO, at least not in this PR:

  • regression
  • conjugate gradient optimization
  • multiple hidden layers
  • L1 penalty
  • Student's t penalty
  • reuse @pprett's Cython code for datasets and weight vectors (took too much refactoring)
  • dropout regularization?
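For illustration, here is a hypothetical usage sketch based on the benchmark configuration listed above; the module path (sklearn.neural), the exact constructor signature and the dataset loader are assumptions pieced together from this thread, not a finished API:

from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.neural import MLPClassifier  # module name is still under discussion

# Hypothetical sketch only: parameter names are taken from the benchmark
# above and may change before this PR is finished.
train = fetch_20newsgroups_vectorized(subset='train')
test = fetch_20newsgroups_vectorized(subset='test')

clf = MLPClassifier(n_hidden=400, alpha=0, max_iter=100,
                    random_state=42, tol=1e-4, learning_rate=.02)
clf.fit(train.data, train.target)
print(clf.score(test.data, test.target))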

@amueller
Member

@larsmans great that you started working on this!
Did you start from scratch? You could have just continued on the existing code...

@larsmans
Member Author

Yes, I started entirely from scratch. I wasn't aware that there was code lying around. If you can post a pointer to it, I can see if I can scavenge useful tricks from it.

@amueller
Member

There was some code by me and additions from David Marek's GSoC:
https://github.com/amueller/scikit-learn/blob/davidmarek_gsoc_mlp/sklearn/mlp

You should look at the weight initialization there; there is a better trick.

So you are doing full batches? With David's Cython version, there was a huge function call overhead and the next step would have been to use blas calls.
For that, a recent PR from Fabian needs to be merged to get gemm.
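For reference, a minimal sketch of a common weight-initialization trick for logistic/tanh hidden layers (the scaled uniform scheme of Glorot & Bengio, 2010); whether this is the trick used in the linked branch is an assumption:

import numpy as np

def init_hidden_weights(n_in, n_out, rng=np.random):
    # Scaled uniform initialization: keeps the variance of the activations
    # roughly constant across layers, which helps gradients early in training.
    bound = np.sqrt(6.0 / (n_in + n_out))
    return rng.uniform(-bound, bound, size=(n_in, n_out))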

@amueller
Member

Have you tried MNIST yet? What are the runtimes?

@amueller
Member

BTW, RPROP is really great ;) It is really easy to implement and gets rid of all learning rate issues. It is a batch method, though.
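For reference, a simplified sketch of the RPROP idea (the sign-based variant, without sign-change backtracking), not code from this PR:

import numpy as np

def rprop_step(w, grad, prev_grad, step,
               eta_plus=1.2, eta_minus=0.5, step_min=1e-6, step_max=50.0):
    # Grow the per-weight step size where the gradient sign is stable,
    # shrink it where the sign flipped since the previous batch; only the
    # sign of the gradient is used, so no global learning rate is needed.
    sign_change = grad * prev_grad
    step = np.where(sign_change > 0, np.minimum(step * eta_plus, step_max), step)
    step = np.where(sign_change < 0, np.maximum(step * eta_minus, step_min), step)
    w = w - np.sign(grad) * step
    return w, step, grad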

@amueller
Member

BTW, I'd usually rather use tanh in the hidden layer.
You have cross-entropy loss in the todo but you already implemented that, didn't you? This is the same as multinomial logistic regression afaik.

@larsmans
Member Author

No, I haven't tried MNIST yet. I was going to do minibatches; load, say, 100 or 1000 samples using @pprett's code, then just use np.dot from there. That should get rid of the function call overhead and at the same time leverage whatever BLAS Numpy uses.

I considered IRPROP, but Hinton, in his Coursera course, mentioned that there's now a minibatch version of RPROP called RMSPROP. I'm not sure if that's even published yet, but I'll check it out. It's apparently related to the "No more pesky learning rates" algorithm, which I still have to read.
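For reference, a minimal sketch of the RMSProp update described in Hinton's lecture notes (not code from this PR): divide the gradient by a running average of its recent magnitude:

import numpy as np

def rmsprop_update(w, grad, mean_sq, lr=1e-3, decay=0.9, eps=1e-8):
    # Keep an exponential moving average of the squared gradient; the
    # effective per-weight step size adapts to recent gradient magnitudes.
    mean_sq = decay * mean_sq + (1.0 - decay) * grad ** 2
    w = w - lr * grad / (np.sqrt(mean_sq) + eps)
    return w, mean_sq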

As for tanh, I know they have larger gradients and all, but in the long run I'm interested in warm starting these things from RBMs to do semisupervised and multitask learning, and RBMs always use logistic hidden units.

I have hinge loss in the TODO. Apparently, LeCun uses (or used to use) that, so it must be good :)

I'll merge your code and mine.

@ogrisel
Member
ogrisel commented Nov 23, 2012

+1 for focusing quickly on minibatches (for out-of-core large-scale learning), AdaGrad for easy per-feature adaptive learning rates, and random dropout, which seems to completely remove the need for early stopping to combat overfitting and allows deeper neural nets to converge to a better solution.


@ogrisel
Member
ogrisel commented Nov 23, 2012

I am not sure which is better between AdaGrad and one of the per-feature no-more-pesky-LR strategies. AdaGrad seems much simpler to implement. In AdaGrad there is still a constant hyperparameter for the learning rate, but I think it is less sensitive and can in practice be approximately grid-searched cheaply in a burn-in phase on the first 10k samples, for instance.

For the no-more-pesky learning rate there is a need to compute an online estimate of the diagonal of the Hessian, and I have no idea whether this is easy to implement or not.

@ogrisel
Member
ogrisel commented Nov 23, 2012

Also for the minibatch implementation, it's very important to pre-allocate the memory of the output buffer of each layer outside of the main fit loop.

@ogrisel
Member
ogrisel commented Nov 23, 2012

Also, does someone know the relationship between:

  • the momentum method (rolling averaging of the gradients),
  • AdaGrad (scaling the learning rate by sqrt(sum of squared per-feature gradients)),
  • (possibly rolling) Polyak-Ruppert averaging of the model itself outside of the main SGD loop?

In particular, whether one strategy is a proxy for / equivalent to another.

@ogrisel
Member
ogrisel commented Nov 23, 2012

Also, this PR could be an opportunity to start using Cython memory views (and maybe fused types too). Memory views could potentially give a perf boost too.

@amueller
Member

AdaGrad basically does the exact opposite of momentum... AdaGrad makes smaller updates in directions that have been frequently updated, while momentum keeps going faster in those directions.

In my lab we didn't have much success with AdaGrad for MLPs. Our hunch is that it might be good for text data with very different feature frequencies, but not for image data or dense data. I haven't tried the no-more-pesky learning rates yet, as it seems a bit of a hassle to implement.

I am not sure if dropout will help much with only one hidden layer, btw.

@larsmans The problem with np.dot is that it gives you the full Python overhead, which can matter if the batches are small. You can see that if you benchmark the other code.
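To make the contrast between the two update rules concrete, a sketch (not code from this PR):

import numpy as np

def momentum_step(w, grad, velocity, lr=0.01, mu=0.9):
    # Momentum keeps a rolling average of past gradients and accelerates
    # in directions that are updated consistently.
    velocity = mu * velocity - lr * grad
    return w + velocity, velocity

def adagrad_step(w, grad, grad_sq_sum, lr=0.01, eps=1e-8):
    # AdaGrad accumulates squared gradients and shrinks the step in
    # directions that have already received large cumulative updates.
    grad_sq_sum = grad_sq_sum + grad ** 2
    return w - lr * grad / (np.sqrt(grad_sq_sum) + eps), grad_sq_sum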

@amueller
Member

Memory views mainly help for looping over the items. Here, it is all about the matrix multiplication. If we use np.dot, it will always go via Python objects...

@amueller
Member

Btw, what is the argument for minibatch vs pure online learning (1 example at a time)?
In my experience the optimization is faster (in terms of weight updates) the smaller the batch. On GPUs there is a tradeoff because larger matrix multiplications are much faster (runtime) than many small ones.
I am not sure how that looks on CPUs.
Any ideas?

@@ -0,0 +1 @@
from .mlperceptron import MLPClassifier
Member

I would use neural_network without an s to be consistent with linear_model.

Member

+1

Member Author

Thanks, but I suggest we postpone this discussion until the code is more or less finished. This is the module name that the RBM PR currently uses.

Member

I'd rather do it now so as to not forget.

Member Author

I myself am pro sklearn.neural. I certainly won't forget this.

Member

I myself am pro sklearn.neural. I certainly won't forget this.

+1

Member Author

Is that a +1 on neural, or on handling this later?

Member

Is that a +1 on neural, or on handling this later?

Sorry: +1 on neural

@ogrisel
Member
ogrisel commented Nov 23, 2012

About minibatches vs pure online

Using minibatches makes it possible to use blas (d/s)gemm (both for forward prop / predictions and backprop of the gradient) which can be much faster to compute than batch_size times (s/d)axpy as it's better able to use SSE3 vector operations (when you use an optimized blas implementation such as atlas / openblas).

Pure online (one example at a time) is better from a purely algorithmic perspective, if CPU scalar and vector operations had the same speed. So there is a trade-off between theoretical convergence speed and the implementation speedup caused by the larger data parallelism that comes from larger minibatches.

Break-even for an MLP implemented on a GPU is around batch_size=100 to 500 if I remember correctly (@dwf can you confirm?). I suppose that in our case, for a CPU implementation with SSE3 operations, it should be closer to the batch_size=10 to 50 range.
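An illustrative comparison of the two regimes in plain NumPy (hypothetical shapes; the point is one gemm per minibatch versus one small product per sample):

import numpy as np

rng = np.random.RandomState(0)
X_batch = rng.randn(50, 784)          # one minibatch of 50 samples
W = rng.randn(784, 400) * 0.01        # input-to-hidden weights

# Minibatch: a single gemm for the whole batch, which BLAS can vectorize.
H_batch = np.tanh(np.dot(X_batch, W))

# Pure online: one small matrix-vector product per sample.
H_online = np.vstack([np.tanh(np.dot(x, W)) for x in X_batch])

assert np.allclose(H_batch, H_online)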

About AdaGrad

Jeff Dean's paper on scaling very large autoencoders on dense features for image data (the traditional local receptive field filter / pool / normalize) shows that AdaGrad yields very significant convergence improvements. These data & models are very different from the original motivation for AdaGrad, which was linear models on large, sparse, high-dimensional data with rare yet very informative features. However, if you have opposite practical feedback on smaller models than Google's, I can believe you too.

About memory views

I think memory views would also make it more natural to directly call blas operations. In the coordinate descent cython code we use casts + pointer arithmetic that would probably go away with memory views (I have never tried to use them myself yet though). Maybe @jakevdp has an opinion on this.

Also +1 for not using np.dot but calling BLAS ((s/d)axpy, or better (s/d)gemm with minibatches) directly to get rid of any Python interpreter & input-check overhead in the inner loop.
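From Python, scipy.linalg.blas gives a rough feel for what a direct gemm call with a preallocated output buffer looks like; in the Cython code the BLAS symbol would be called directly, so this is only a sketch:

import numpy as np
from scipy.linalg.blas import dgemm

rng = np.random.RandomState(0)
X = np.asfortranarray(rng.randn(50, 784))   # minibatch, Fortran order for BLAS
W = np.asfortranarray(rng.randn(784, 400))  # input-to-hidden weights
H = np.zeros((50, 400), order='F')          # preallocated output buffer

# H <- 1.0 * X @ W + 0.0 * H, written into the existing buffer.
H = dgemm(1.0, X, W, beta=0.0, c=H, overwrite_c=True)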

@ogrisel
Member
ogrisel commented Nov 23, 2012

Also I think @npinto has had good experience with piece-wise linear approximation of the logistic activation function (to get rid of the expensive log / exp operations). Could you confirm?

@amueller
Member

I don't think the log/exp play a role compared to the blas calls.
I completely agree about the memory views: they should make the blas calls much easier. I was trying to say that they don't help as long as we are using np.dot.

@ogrisel
Member
ogrisel commented Nov 23, 2012

I agree with postponing any optimization of log/exp until after the rest (minibatch / BLAS optimizations for the dot products).

@larsmans
Member Author

@amueller I haven't done full profiling, but I don't think log/exp is currently that heavy compared to the O(n³) cost of matrix multiplication. The reason for minibatches is indeed to leverage BLAS optimizations.

@ogrisel The problem with calling BLAS directly is that the BLAS we build against will be called, so you'd lose the ability to plug in a different implementation. However, if it turns out Python overhead is really that bad, I suggest we borrow code from Tokyo to do this; I realize it stands in the way of OpenMP, which might be something to try here, too.

I was just looking at memoryviews :)

@larsmans
Member Author

@ogrisel: I'm surprised that no more pesky LR involves the diagonal of the Hessian; such a strategy is mentioned by Bishop, but he notes that "in practice the Hessian is typically found to be strongly nondiagonal".

@ogrisel
Member
ogrisel commented Nov 23, 2012

@ogrisel: I'm surprised that no more pesky LR involves the diagonal of the Hessian; such a strategy is mentioned by Bishop, but he notes that "in practice the Hessian is typically found to be strongly nondiagonal".

It might be, but it's just not practical to store a full n_parameters ** 2 Hessian in memory.

@larsmans
Member Author

I understand that, and I'm certainly not suggesting to go quasi-Newton here. I fear getting all this to work right might require reading more of Nocedal than I have time for, so I'd rather focus on implementing the easy tricks right for now. With a proper modular setup, we could plug in a different optimizer later.

@briancheung
Contributor

Also, you would need to invert that Hessian matrix for SGD. The diagonal approximation is discussed in "Efficient Backprop" by LeCun et al. They also mention a lot of the standard tricks for transforming the features for neural nets to make them easier to train. In Section 7.3, you approximate the Hessian with the square of the Jacobian (Gauss-Newton) and take the diagonal of that result.

I agree that the optimizer should be modular, as there are probably going to be quite a few accepted ways of dealing with the learning rate. I think you're going to have to tie the Hessian calculations to your backpropagation method, so you might have to think carefully about how to make this modular.

@dwf
Member
dwf commented Nov 25, 2012

Something to also look at is Nesterov momentum vs. classic momentum. A thorough treatment of the former can be found in Ilya Sutskever's thesis, chapter 7.
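A sketch of the difference (not code from this PR): Nesterov momentum evaluates the gradient at the looked-ahead point instead of at the current parameters:

def classical_momentum_step(w, grad_fn, v, lr=0.01, mu=0.9):
    # Classical momentum: gradient evaluated at the current parameters.
    v = mu * v - lr * grad_fn(w)
    return w + v, v

def nesterov_momentum_step(w, grad_fn, v, lr=0.01, mu=0.9):
    # Nesterov momentum: gradient evaluated at the looked-ahead point
    # w + mu * v, which corrects the step before it is applied.
    v = mu * v - lr * grad_fn(w + mu * v)
    return w + v, v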

@npinto
Contributor
npinto commented Nov 26, 2012

Also I think @npinto has had good experience with piece-wise linear approximation of the logistic activation function (to get rid of the expensive log / exp operations). Could you confirm?

Hey guys, not sure what I can contribute here, except maybe that (from our experience, unpublished) we have seen that:

  1. the simplest half-rectified linear activation works best (confirmed by many others since)
  2. fmin_l_bfgs_b from scipy.optimize gets better (and faster) results than custom-made Hessian-free stuff
  3. if you are interested in 'dropout', it looks like it does not add much value compared to bagging (with warm restarts, "ASGD-like", to speed up convergence) or adding different types of noise (like denoising auto-encoders, or adding Gaussian noise for L2 regularization).

Bottom line: simple things work (and scale) much better, without any fuss.

(sorry for the brief tangent, I just wanted to share this with the scikit-learn community)

I'm happy to discuss further and I should probably write these things down... oh well...

HTH

@amueller
Member

@npinto Thanks for sharing your insights and experience!

@npinto
Contributor
npinto commented Nov 26, 2012

On: "Nesterov momentum / subgradients", Polyak Averaging (btw, their books are awesome [1, 2]), AdaGrad, Hessian approx, etc.

It looks (again from our experience) like learning with "temporal coherence" gets similar performance "boosts". Even better: with the right "natural" data you can learn the "pooling" (e.g. in conv nets) without hardcoding it with a different hacky operation ;-) If one looks deeper, it kind of makes sense (at least intuitively), as they all have a very similar flavor...

Maybe some theory is missing? But nothing new here on the empirical side: these simple ideas were presented in the '80s and '90s but never exploited at scale, or were sometimes repackaged with a different name tag attached and a diluted/confusing explanation.

[1] Introductory Lectures on Convex Optimization - A Basic Course (Nesterov, 2004)
[2] Introduction to Optimization (Polyak, 1987)
(djvu files are online)

@larsmans larsmans mentioned this pull request Feb 4, 2013
Number of units in the hidden layer.
activation: string, optional
Activation function for the hidden layer; either "logistic" for
1 / (1 + exp(-x)), or "tanh" for the hyperbolic tangent.
Member

What about adding rectified linear units? They should be much faster to compute than transcendental activation functions (and could be regularized with a piecewise-linear regularizer as well: http://arxiv.org/abs/1301.3577, even if that paper introduces it for autoencoders).

Member Author

Actually, some code for ReLUs and softplus is already in place, but not mentioned here yet.

Member

Indeed although it's currently commented out. Why?

Member Author

I think they don't even typecheck ATM.
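For reference, the standard definitions are simple; this is a sketch of the usual formulas, not the branch's commented-out code:

import numpy as np

def relu(x):
    # Rectified linear unit: max(0, x); no transcendental functions needed.
    return np.maximum(0.0, x)

def relu_derivative(x):
    return (x > 0).astype(x.dtype)

def softplus(x):
    # Smooth variant of the rectifier: log(1 + exp(x)).
    return np.log1p(np.exp(x))

def softplus_derivative(x):
    # The derivative of softplus is the logistic function.
    return 1.0 / (1.0 + np.exp(-x))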

@ogrisel
Member
ogrisel commented Feb 5, 2013

It would be great to refactor the code to be able to implement a public partial_fit method for minibatches. To make this efficient, the temporary buffers should be preallocated outside of the Cython helper during the first call to partial_fit and stored as private attributes of the Python estimator class, so they can be reused without a new call to malloc on subsequent calls to partial_fit (unless a benchmark proves that this is a useless optimization).
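A minimal sketch of the buffer-caching idea; the class, attribute names and shapes below are hypothetical, not part of this PR:

import numpy as np

class BufferedMLP(object):
    # Hypothetical estimator fragment: allocate per-layer buffers on the
    # first call to partial_fit and reuse them on subsequent calls.
    def __init__(self, n_hidden=100):
        self.n_hidden = n_hidden

    def partial_fit(self, X, y):
        n_samples, n_features = X.shape
        if (getattr(self, '_hidden_buf', None) is None
                or self._hidden_buf.shape[0] != n_samples):
            self._hidden_buf = np.empty((n_samples, self.n_hidden))
        if not hasattr(self, 'coef_hidden_'):
            self.coef_hidden_ = np.zeros((n_features, self.n_hidden))
        # Forward pass writes into the preallocated buffer (no new allocation).
        np.dot(X, self.coef_hidden_, out=self._hidden_buf)
        # ... activation, output layer and backprop would follow here ...
        return self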

@temporaer
Contributor

For this refactoring, a major question, I believe, is: are we going to support multiple hidden layers (I'd conjecture dropout will mainly help with more than one)?

Should we model the layers as classes, so they can be extended (dropout again, but also the half-rectified linear function mentioned by @npinto)?

Do we want to model the weight update as a class, so it can be extended (RPROP, momentum, AdaGrad, RMSProp, ...)?

@larsmans
Member Author
larsmans commented Feb 5, 2013

I don't want multiple layers for now -- let's get it pulled first and extend it later. I like the idea of weight update classes. I considered RMSProp but rejected it because it introduced yet another hyperparameter; with class instances encapsulating them, that should be less of a problem.

@ogrisel I'm not quite sure what you mean; do you want to feed batches of, say, 100 samples per call to partial_fit? The idea now is that minibatch learning is available without extra cost. So you can extract 10k samples at a time, or however many you can fit in RAM, then feed those to partial_fit in a single call and have the SGD implementation divide that into much smaller minibatches for you. That should avoid a lot of Python overhead in the feature extraction/selection/transformation.
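Roughly, the intended calling pattern (hypothetical helper names, only to illustrate the split between the caller's chunk size and the internal minibatch size):

def iter_minibatches(X, y, batch_size=100):
    # Internal division of a large chunk into small SGD minibatches.
    for start in range(0, X.shape[0], batch_size):
        yield X[start:start + batch_size], y[start:start + batch_size]

# Caller side: stream chunks as large as fit in RAM (e.g. 10k samples) and
# let partial_fit / the SGD loop slice them further:
#
#     for X_chunk, y_chunk in stream_chunks(chunk_size=10000):  # hypothetical
#         clf.partial_fit(X_chunk, y_chunk)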

@ogrisel
Member
ogrisel commented Feb 6, 2013

The idea now is that minibatch learning is available without extra cost. So you can extract 10k samples at a time, or however many you can fit in RAM, then feed those to partial_fit in a single call and have the SGD implementation divide that into much smaller minibatches for you. That should avoid a lot of Python overhead in the feature extraction/selection/transformation.

That's an interesting pragmatic trade-off, although it makes it less obvious to the user what the actual learning batch size is (the number of samples used for each subgradient computation + weight update). But fair enough.

@temporaer
Contributor

@larsmans , can you please share the file you used for testing with the 20news dataset (assuming I didn't miss it)?

Also, I made quite a bit of progress on the refactoring side and would welcome feedback.
Feel free to browse through my mlperceptron branch: https://github.com/temporaer/scikit-learn/tree/mlperceptron
Assuming it is useful, shall I open a pull request against your branch, or how is this usually done?

@larsmans
Member Author
larsmans commented Mar 5, 2013

Example script is now a gist: https://gist.github.com/larsmans/5089191

And yes, please send PRs to my repo.

@temporaer
Contributor

Do you have comments on how to do early stopping, especially splitting the dataset, in the classifier? Is there/will there be framework support/guidelines for this?

I guess it would have to have these components, which are (too?) tightly entangled with the main training loop (a rough sketch follows the list):

  1. split dataset into train/val
  2. in the training loop, support calling predict() on the validation set (this could be done with a larger batch size for the MLP, to speed things up)
  3. measure the loss using a possibly different loss function (e.g. zero-one loss)
  4. if the early-stopping loss improves, serialize the classifier and possibly increase some patience variable (patience=epoch*2); otherwise decrease patience (patience-=1)
  5. if we run out of patience, break the training loop and restore the classifier from the best serialized state
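A rough sketch of such a loop (hypothetical API: it assumes the estimator exposes a per-epoch partial_fit and a score method):

import copy

def fit_with_early_stopping(clf, X_train, y_train, X_val, y_val,
                            max_epochs=100, initial_patience=5):
    best_score = float('-inf')
    best_clf = None
    patience = initial_patience
    for epoch in range(1, max_epochs + 1):
        clf.partial_fit(X_train, y_train)        # one epoch of training
        score = clf.score(X_val, y_val)          # e.g. 1 - zero-one loss
        if score > best_score:
            best_score = score
            best_clf = copy.deepcopy(clf)        # "serialize" the best model
            patience = max(patience, 2 * epoch)  # patience = epoch * 2
        else:
            patience -= 1
        if patience <= 0:
            break
    return best_clf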

@amueller
Member
amueller commented Mar 5, 2013

For my point of view, see #1626

@larsmans
Member Author
larsmans commented Mar 5, 2013

I suggest we leave early stopping out of the PR.

@IssamLaradji
Contributor

Hi @larsmans, excellent MLP implementation 👍. I'm wondering whether you would consider adding a weight initialization hyperparameter. Estimating near-optimal initial weights could lead to faster convergence and perhaps better prediction performance. For example, one might set the initial weights to the weights generated by a Restricted Boltzmann Machine. :)

Thanks

@larsmans
Member Author
larsmans commented Jun 5, 2013

Yes, bootstrapping from an RBM is planned, but not initially. I haven't had enough time to finish the features I have planned currently, let alone implement additional ones.

@IssamLaradji
Contributor

Thanks @larsmans, I installed your code on my local machine, but when I tried fitting the MLPClassifier to matrices X, y I get the error at [1].

Both X and y have the same number of rows; their shapes are:

In [8]: X.shape
Out[8]: (32769L, 9L)
In [9]: y.shape
Out[9]: (32769L,)

[1] The error

ValueError                                Traceback (most recent call last)
<ipython-input-7-12b5b6bd5106> in <module>()
----> 1 clf.fit(X,y)

C:\Anaconda\lib\site-packages\sklearn\neural\mlperceptron.pyc in fit(self, X, y)
    116         print X.shape
    117         print y.shape
--> 118         backprop_sgd(self, X, Y)
    119 
    120         return self

C:\Anaconda\lib\site-packages\sklearn\neural\backprop_sgd.pyd in sklearn.neural.backprop_sgd.backprop_sgd (sklearn\neural\backprop_sgd.c:5252)()

ValueError: matrices are not aligned

Thanks!

@larsmans
Member Author

The code is not finished and cannot be expected to work ATM.

@IssamLaradji
Contributor

Keep up the amazing work @larsmans :)
