Add 32 bit support to neural_network module #17700
Comments
I was under the impression that we decided to stop providing new features / improvements to the NN module?
Stop providing features yes, performance improvements no, I think. As long as we do include it, people will use it in some simple cases where installing pytorch/tensorflow is a hassle, and there is no reason to keep our implementation slower than it can be. Here we are talking about easy changes to half a dozen lines (similar to #16352) with a clear performance improvement.
@rth, I've tried passing clf = MLPClassifier(alpha=1e-5, hidden_layer_sizes=(10, 5, 3), random_state=1, max_iter=100, dtype=np.float32). Internally, I cast all the network parameters to this dtype.
The dataset is something like this: X, y = make_classification(n_samples=10000, n_features=30, n_informative=15, n_classes=5, random_state=0)
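For context, a rough sketch of what that benchmark setup could look like. The dtype parameter mentioned above is hypothetical (it only exists in the commenter's local changes), so this sketch instead times fit() on float64 versus float32 input and assumes the input dtype would be preserved internally once 32 bit support is added:

```python
import numpy as np
from time import perf_counter
from sklearn.datasets import make_classification
from sklearn.neural_network import MLPClassifier

X, y = make_classification(n_samples=10000, n_features=30, n_informative=15,
                           n_classes=5, random_state=0)

for dtype in (np.float64, np.float32):
    clf = MLPClassifier(alpha=1e-5, hidden_layer_sizes=(10, 5, 3),
                        random_state=1, max_iter=100)
    # Without 32 bit support, float32 input is upcast to float64 internally,
    # so both timings come out roughly equal; with it, the second run should
    # stay in float32 end to end.
    start = perf_counter()
    clf.fit(X.astype(dtype), y)
    print(dtype.__name__, "fit time:", perf_counter() - start)
```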
Great! I was hoping that it would be a bit more significant, but I guess it also depends on the CPU vectorization support. Maybe let's not add the dtype parameter after all, but rather do as in #16352: run the input validation accepting both dtypes first, and then use the dtype of the validated X.
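A minimal sketch of what that #16352-style approach could look like inside the module; the helper names here are illustrative, not the actual private API of MLPClassifier:

```python
import numpy as np
from sklearn.utils import check_array


def _validate_input(X):
    # Accept both float64 and float32; float32 input is preserved rather than
    # upcast to float64 (the pattern used for BernoulliRBM in #16352).
    return check_array(X, dtype=(np.float64, np.float32))


def _init_layer(n_in, n_out, dtype, rng):
    # Initialize weights and intercepts in the dtype of the validated input so
    # that the dot products in the forward/backward passes stay in 32 bit.
    bound = np.sqrt(6.0 / (n_in + n_out))
    coef = rng.uniform(-bound, bound, size=(n_in, n_out)).astype(dtype, copy=False)
    intercept = np.zeros(n_out, dtype=dtype)
    return coef, intercept


# Example: float32 input stays float32 through validation and initialization.
X32 = _validate_input(np.ones((5, 3), dtype=np.float32))
coef, intercept = _init_layer(3, 4, X32.dtype, np.random.RandomState(0))
print(X32.dtype, coef.dtype, intercept.dtype)  # float32 float32 float32
```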
Sure @rth, will open a PR in a bit.
Related to #11000, it would be good to add support for 32 bit computations for estimators in the neural_network module. This was done for BernoulliRBM in #16352. Because performance is bound by the dot product, this is going to have a large impact (cf #17641 (comment)). I would even argue that, unlike other models, it could make sense to add a dtype=np.float32 parameter and make calculations in 32 bit by default regardless of X.dtype. We could also consider supporting dtype=np.float16.
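As a standalone illustration of the "bound by the dot product" point (actual speedups will vary with the BLAS implementation and CPU vectorization support), a quick timing sketch comparing the matrix products that dominate MLP training in float64 versus float32:

```python
import numpy as np
from time import perf_counter

rng = np.random.RandomState(0)
A64 = rng.rand(10000, 30)    # activations: n_samples x n_features
W64 = rng.rand(30, 100)      # layer weights: n_features x n_hidden
A32, W32 = A64.astype(np.float32), W64.astype(np.float32)

for name, A, W in [("float64", A64, W64), ("float32", A32, W32)]:
    start = perf_counter()
    for _ in range(1000):    # repeat the product to get a measurable timing
        A @ W
    print(name, perf_counter() - start)
```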