[MRG] Multi-layer perceptron (MLP) #2120


Closed

Conversation

IssamLaradji
Contributor

Multi-layer perceptron (MLP)

PR closed in favor of #3204


This is an extension of @larsmans's code.

A multilayer perceptron (MLP) is a feedforward artificial neural network model that tries to learn a function f(X) = y, where X is the input and y is the output. An MLP consists of multiple layers: an input layer, one or more hidden layers (usually one), and an output layer, with each layer fully connected to the next. It is a classic algorithm that has been used extensively in neural networks.

Code checkout:

  1. git clone https://github.com/scikit-learn/scikit-learn
  2. cd scikit-learn/
  3. git fetch origin refs/pull/2120/head:mlp
  4. git checkout mlp

Tutorial link:

- http://easymachinelearning.blogspot.com/p/multi-layer-perceptron-tutorial.html

Sample Benchmark:

`MLP` on scikit-learn's `Digits` dataset gives:

- Score for `tanh-based sgd`: 0.981
- Score for `logistic-based sgd`: 0.987
- Score for `tanh-based l-bfgs`: 0.994
- Score for `logistic-based l-bfgs`: 1.000
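For reference, a rough way to run a similar benchmark against the API that was eventually merged (modern scikit-learn); the class and parameter names in this PR differ slightly, and exact scores will vary:

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

for activation in ("tanh", "logistic"):
    for solver in ("sgd", "lbfgs"):
        clf = MLPClassifier(hidden_layer_sizes=(50,), activation=activation,
                            solver=solver, max_iter=400, random_state=0)
        clf.fit(X_train, y_train)
        print(activation, solver, clf.score(X_test, y_test))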

TODO:

- Review

from ..utils.extmath import logsumexp, safe_sparse_dot


def validate_grad(J, theta, n_slice):
Member

you could move this into the test file.

Contributor Author

I was wondering where to put it. Thanks!

@amueller
Member
amueller commented Jul 1, 2013

Could you please say to what extent this extends @larsmans's PR? This one seems to be completely in Python, while I remember @larsmans's to be in Cython, right? I'm not completely sure how large the benefit of Cython was, though.
Does this one support sparse matrices? cc @temporaer

@IssamLaradji
Contributor Author
  1. @amueller, the missing part in @larsmans's PR was the backpropagation, which was partly developed in Cython. I implemented that part using vectorized matrix operations, and it's quite fast: for example, running the algorithm on the `digits` dataset for 400 iterations with 50 hidden neurons takes about 5 seconds. I also added the option of using a secondary optimization algorithm, `fmin_l_bfgs_b`, which is just as fast but achieves better classification performance with the same number of iterations (see the sketch below). I'm also thinking of adding a third option, `fmin_cg`. These optimizers are heavily used in neural networks.

I read that Cython code is easy to produce (just a matter of adding some type declarations and compiling the code). I will Cythonize the code and see if it adds any benefit.

  2. Yes, it supports sparse matrices (via `safe_sparse_dot`).
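A minimal sketch of how such an L-BFGS fit can be wired up with SciPy; `loss_and_grad` here is a placeholder, not this PR's actual cost function:

import numpy as np
from scipy.optimize import fmin_l_bfgs_b

def loss_and_grad(packed_weights):
    # In the real code: unpack the weights, run a forward pass, backpropagate,
    # and return (cost, packed gradient). A quadratic stand-in is used here.
    loss = float(np.sum(packed_weights ** 2))
    grad = 2.0 * packed_weights
    return loss, grad

w0 = np.random.RandomState(0).uniform(-0.1, 0.1, size=100)
w_opt, final_loss, info = fmin_l_bfgs_b(loss_and_grad, w0, maxiter=200)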

Thanks for the review

activation: string, optional
Activation function for the hidden layer; either "sigmoid" for
1 / (1 + exp(-x)), or "tanh" for the hyperbolic tangent.
_lambda : float, optional
Member

We call this alpha almost everywhere else. Don't use a leading underscore on a parameter.

@larsmans
Member
larsmans commented Jul 1, 2013

I found that switching to Cython gave about an order of magnitude improvement over pure Python. We can merge this version as an intermediate; it looks very clean. How fast is it on 20newsgroups w/ 100 hidden units?

@amueller
Member
amueller commented Jul 1, 2013

Really? Why? In the other implementation, there was no gain at all. I guess you used one-sample "mini-batches"?

@larsmans
Member
larsmans commented Jul 1, 2013

No, my implementation would take large batches, divide these into randomized minibatches internally (of user-specified size), then train on those. That gave much faster convergence, without the need to actually materialize the minibatches (no NumPy indexing).

@amueller
Member
amueller commented Jul 1, 2013

OK, that makes sense and explains why the Cython version is much faster.

@IssamLaradji
Contributor Author

@larsmans, using all 20 categories of 20newsgroups (not the watered-down version), vectorized with scikit-learn's tf-idf vectorizer into an input matrix of 18828 rows and 74324 columns (features), and with 100 hidden neurons, the algorithm fit the whole sparse matrix at around 1 second per iteration. That seems like a reasonable speed for an MLP on data this large, but I might be wrong.

@larsmans
Member
larsmans commented Jul 1, 2013

What's the F1 score, and how many iterations are needed for it? (I got somewhat faster results, but I wonder if LBFGS converges faster than SGD.)

@IssamLaradji
Contributor Author

Right now, I have applied the code to 4 categories of the 20newsgroups corpus, with 100 iterations and 100 hidden neurons; l_bfgs achieved an average F1-score of 0.87. I might need to let the code run for a long time before it converges (it doesn't converge even after 300 iterations), so I suspect there is a bug in my code.

In your pull request you mentioned that you tested your code on a small subset of the 20newsgroups corpus and achieved similar results; did you use 4 categories too?

inds = np.arange(n_samples)
rng.shuffle(inds)
n_batches = int(np.ceil(len(inds) / float(self.batch_size)))
# Transpose improves performance (from 0.5 seconds to 0.05)
Member

Improves the performance of what? For dense or sparse data?

Contributor Author

The improvement was in calculating the cost and the gradient. It was observed on dense data; I haven't tried it on sparse data yet.

It might look peculiar, but it has to do with the matrix multiplications. I played with the timeit library to understand the speed-up: if multiplying matrices A and B takes, say, 0.25 ms, then multiplying B.T with A.T can take twice as long, about 0.5 ms. That small difference adds up when the cost and gradient are calculated many times. In other words, multiplying matrices with unfavourable memory layouts can incur time overheads.

I will commit the non-transposed function and benchmark again, just to be safe.

Member

It's not too surprising; it's probably due to Fortran vs. C arrays (column-major vs. row-major). Input to np.dot should ideally be a C array and a Fortran array, in that order, IIRC.
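For instance, a quick, machine-dependent way to check the layout effect (the absolute numbers depend on the BLAS build):

import numpy as np
from timeit import timeit

rng = np.random.RandomState(0)
A = np.ascontiguousarray(rng.rand(2000, 500))   # C-ordered (row-major)
B = np.asfortranarray(rng.rand(500, 2000))      # Fortran-ordered (column-major)

print("A @ B    :", timeit(lambda: np.dot(A, B), number=20))
print("B.T @ A.T:", timeit(lambda: np.dot(B.T, A.T), number=20))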

Member

But anyway, my point was that performance figures out of context don't belong in code :)

Contributor Author

Thanks! I'll have such comments removed soon :)

@larsmans
Member
larsmans commented Jul 3, 2013

I have some MLP documentation lying around, I'll see if I can dig that up, too.

@@ -0,0 +1,617 @@
"""Mulit-layer perceptron
Member

Typo

@IssamLaradji
Contributor Author

Sorry for being slow to respond; I had a bug in the code that took time to fix because the transposed X was confusing everything :). A misleading benchmark had made me think that X.T improved performance, but in reality it did not, so I removed the transpose, making the code cleaner and easier to debug with performance unchanged.

Moreover, I just committed a lot of changes, including:

  • An `optimization_method` parameter for selecting any SciPy optimizer
  • Support for SGD
  • Improved minibatch creation using scikit-learn's gen_even_slices
    • (much faster than X[inds[minibatch::n_batches]]; see the sketch after this list)
  • Support for the cross-entropy and squared loss functions (more will be added)
  • Typo and name fixes
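A small illustration of the slice-based minibatches, using `gen_even_slices` from `sklearn.utils` (the update step is a placeholder):

import numpy as np
from sklearn.utils import gen_even_slices

X = np.random.RandomState(0).rand(1000, 64)
n_batches = 8
for batch_slice in gen_even_slices(X.shape[0], n_batches):
    X_batch = X[batch_slice]   # contiguous slice: a view, no fancy indexing
    # ...one gradient update on X_batch...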

The performance benchmark on the digits dataset (100 hidden neurons and 170 iterations),

  • SGD with cross-entropy loss
    • Score : 0.95
  • Optimization using CG, i.e. Conjugate Gradient, with cross-entropy loss
    • Score : 0.95
  • Optimization using l-bfgs with square loss
    • Score : 0.98 (it has converged in 80 iterations)
  • Please note that the score is worse with the squared loss for SGD and CG.

Will post the test results on the 20News dataset soon.

Some of the remaining TODO's would be:

  • Use sqrt(n_features) to select the number of hidden neurons
  • Update the documentation
  • Add verbose
  • Add a test file
  • Add an example file

Thank you for your great reviews!

@larsmans
Member
larsmans commented Jul 5, 2013

You can get a unit test by

git remote add larsmans git@github.com:larsmans/scikit-learn.git
git fetch larsmans
git cherry-pick 3176ce315b38176484fd57357496f8a6a0589071

I'd send you a PR, but it looks like the GitHub UI changes broke PRs between forks.

@IssamLaradji
Contributor Author

@larsmans the unit test is very useful! I will rework the code as per the comments.

Thanks for the review!

@IssamLaradji
Contributor Author

Updates

- Replaced SciPy's `minimize` with `l-bfgs`, so as not to compel users to install SciPy 0.13.0+
- Renames done, as per the comments
- Split the class into `MLPClassifier` and `MLPRegressor`
- Set `square_loss` as the default for `MLPRegressor`
- Set `log` as the default for `MLPClassifier`
- Fixed long lines (some lines might still be long)
- Added learning rates: `optimal`, `constant`, and `invscaling`
- New benchmark on the `digits` dataset (100 iterations, 100 hidden neurons, `log` loss):
  - `tanh-based SGD`: `0.957`
  - `tanh-based l_bfgs`: `0.985`
  - `logistic-based SGD`: `0.992`
  - `logistic-based l_bfgs`: `1.000` (converged in `70` iterations)

These are interesting results, because tanh should make the algorithm converge faster than logistic. I suspect a bug in computing the deltas (lines 454 to 466) is behind these odd results. I'll use the unit test to make sure the backpropagation is working as it should.
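For reference, the standard derivative forms that the hidden-layer deltas rely on, written in terms of the already-computed activations (a sketch, not this PR's exact code):

import numpy as np

def logistic_derivative(a):
    # a = 1 / (1 + exp(-z))  =>  da/dz = a * (1 - a)
    return a * (1.0 - a)

def tanh_derivative(a):
    # a = tanh(z)            =>  da/dz = 1 - a**2
    return 1.0 - a ** 2

# hidden delta = (delta_next @ W_next.T) * derivative(a_hidden)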

The documentation will be updated once the code is deemed reliable.

Thank you for your great reviews and tips! :)

-------
x_new: array-like, shape (M, N)
"""
X *= -X
Member

You need to indicate in a very visible way in the docstring that the input data is modified.

Contributor Author

@GaelVaroquaux, so would it be better to write "Returns the value computed by the derivative of the hyperbolic tan function" instead of "Computes the derivative of the hyperbolic tan function" in line 100?

or "Modifies the input 'x' via the computation of the tanh derivative"?

Thanks

Member

If you make all of these private (prepend an _ to the name), then a single comment above them would be enough, I think.
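For instance, something along these lines, with the in-place behaviour stated once (`_dtanh` here is just an illustrative name):

def _dtanh(X):
    """Derivative of tanh, computed in place on the tanh activations X.

    WARNING: X is modified (and returned); callers must not rely on the
    original values afterwards.
    """
    X *= -X
    X += 1
    return X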

Contributor Author

Thanks, fixed.
On another note, is there a way to get the Travis build to pass? Does it fail because the test cases are not established yet? Thanks

Member

Travis runs a bunch of tests on the entire package, including your code because it detects a class inheriting from ClassifierMixin. You should inspect the results to see what's going wrong.

Contributor Author

Thanks, the errors are very clear now :)

@GaelVaroquaux
Member

The neural_network sub-package needs to be added to the setup.py of sklearn. Otherwise it does not get copied during the install.
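A sketch of the kind of change meant, in sklearn/setup.py (numpy.distutils based at the time); the surrounding lines are illustrative, not the exact file contents:

def configuration(parent_package='', top_path=None):
    from numpy.distutils.misc_util import Configuration
    config = Configuration('sklearn', parent_package, top_path)
    # ...existing subpackages...
    config.add_subpackage('neural_network')
    config.add_subpackage('neural_network/tests')
    return config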

}
self.activation = activation_functions[activation]
self.derivative = derivative_functions[activation]
self.output_func = activation_functions[output_func]
luoq

Parameters of `__init__` should not be changed. Same for random_state below. See http://scikit-learn.org/stable/developers/index.html#apis-of-scikit-learn-objects
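A sketch of the convention being referenced: __init__ stores the constructor parameters verbatim, and any lookup or validation is deferred to fit(). The class and attribute names here are illustrative, not this PR's code.

import numpy as np


class MLPLikeEstimator(object):
    def __init__(self, activation='tanh', random_state=None):
        self.activation = activation      # stored as given, never overwritten
        self.random_state = random_state

    def fit(self, X, y):
        activation_functions = {
            'tanh': np.tanh,
            'logistic': lambda z: 1.0 / (1.0 + np.exp(-z)),
        }
        # Resolve the string to a function here, on a separate private attribute.
        self._activation_func = activation_functions[self.activation]
        # ...rest of the fitting code...
        return self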

Contributor Author

@luoq thanks for pointing that out, I fixed the initialization disagreements and pushed the new code.

~Issam

@IssamLaradji
Contributor Author

Hi everyone, sorry for being inactive on this; it's been a laborious two weeks :). I have updated the code, improving the documentation and eliminating the tanh-related problems. Since tanh can yield negative values, applying the log function to its output produces an error, so I added code that scales the output into the [0, 1] range to keep it positive, via 0.5 * (a_output + 1).

The code seems to be accepted by the Travis build. MLPRegressor is yet to be implemented, but that will be done soon.

PS: I'm also writing a blog that aims to help newcomers to the field (maybe even practitioners) engage with neural networks:
http://easydeeplearning.blogspot.com/p/multi-layer-perceptron-tutorial.html

Thanks in advance!

@arjoly
Member
arjoly commented Aug 2, 2013

Instead of using the abbreviation MLP, why not write it out as MultilayerPerceptron? That way you would get the more readable class names MultilayerPerceptronClassifier, MultilayerPerceptronRegressor and BaseMultilayerPerceptron.

@larsmans
Member
larsmans commented Aug 2, 2013

Also, can you rebase on master? We merged RBMs, so there's a neural_network module now.

from itertools import cycle, izip


def _logistic(x):
Member

In master, there's a fast logistic function in sklearn.utils.extmath.

Contributor Author

Thanks, I will plug it in!

@IssamLaradji
Contributor Author

@arjoly, the naming is a good idea, thanks!
@larsmans sure, I will rebase it.

@IssamLaradji
Contributor Author

Okay, done :). I have fixed binary classification; I'm getting a 100% score with logistic as well as tanh on a binary dataset generated from scikit-learn's Digits dataset. It turns out that I had to apply the logistic function on the output layer regardless of the activation function in the hidden layer, and the loss function is
-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
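A sketch of the output-layer setup described here for a binary problem: a logistic unit on the output regardless of the hidden activation, scored with the loss quoted above (all names are illustrative):

import numpy as np

def binary_log_loss(y_true, y_pred):
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

rng = np.random.RandomState(0)
hidden_activations = np.tanh(rng.randn(5, 10))     # e.g. a tanh hidden layer
W_out = rng.randn(10, 1) * 0.1
logits = hidden_activations.dot(W_out).ravel()
y_pred = 1.0 / (1.0 + np.exp(-logits))             # logistic output layer
y_true = np.array([0, 1, 1, 0, 1])
print(binary_log_loss(y_true, y_pred))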

Happily, it passed the Travis test. What is left now is to reuse some of scikit-learn's Cython-based loss functions (and logistic) for improved speed, and to implement MLP for regression.

In addition, the packing and unpacking methods are to be improved.

@IssamLaradji
Contributor Author

Hi guys, I am closing this pull request because of the very long discussion.
Here is the new pull request: #3204.

Thanks
