MRG: Faster spatial trees with enhancements #1732


Closed
jakevdp wants to merge 1 commit

Conversation

@jakevdp (Member) commented Mar 3, 2013

New Features

  • The BallTree class is enhanced to support 20 different distance metrics, with depth-first, breadth-first, single-tree, and dual-tree algorithms for the queries. The interface is designed to be fully backward compatible with the old code: all test scripts from previous versions pass on this new version.
  • The same core binary tree code that implements the Ball Tree is also used to implement a KD Tree, which can be faster in low dimensions for some datasets.
  • Both BallTree and KDTree now have a method for performing fast kernel density estimates in O[N log(N)] rather than the naive O[N^2]. Six common kernels are implemented, and each kernel works with any of the valid distance metrics.
  • Both BallTree and KDTree have a method for computing fast 2-point correlation functions, also in O[N log(N)] rather than the naive O[N^2]. This works with any of the valid distance metrics.
  • As a bonus, there is a set of class interfaces to the various distance metrics, with functions to generate a pairwise distance matrix under any of the metrics.
  • All functions are documented and tested. A short usage sketch follows below.
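
For illustration, here is a minimal usage sketch of the interface described above. The class names, the metric keyword, and the kernel_density / two_point_correlation methods follow the summary; exact signatures in the final code may differ.

import numpy as np
from sklearn.neighbors import BallTree, KDTree

rng = np.random.RandomState(0)
X = rng.random_sample((2000, 3))        # 2000 points in 3 dimensions

# build a ball tree with a non-default metric
bt = BallTree(X, leaf_size=30, metric='minkowski', p=3)
dist, ind = bt.query(X[:5], k=3)        # k-nearest-neighbors query
ind_r = bt.query_radius(X[:5], r=0.2)   # radius (RNN) query

# the same core code also backs a KD tree
kdt = KDTree(X, leaf_size=30)

# fast kernel density estimate and two-point correlation function
density = kdt.kernel_density(X[:5], h=0.1, kernel='gaussian')
r = np.linspace(0, 1, 10)
counts = kdt.two_point_correlation(X[:5], r)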

Remaining Tasks

These tasks can be added to this PR, or saved for a later one. I wanted to get some feedback on the changes before going forward with these:

  • Update the NearestNeighbors classes to use the available metrics, as well as to use the new KDTree class
  • Rework the narrative documentation for nearest neighbors to reflect these changes.
  • Create estimator classes to wrap the kernel density estimator functionality, and perhaps also the 2-point correlation function

Speed

The new code is generally faster than both the old ball tree implementation and the recently enhanced scipy cKDTree. Here are some runtimes, in seconds, for queries on 2000 points on my slightly old Linux laptop (KNN = k-nearest-neighbors query, RNN = radius-neighbors query, KDE = kernel density estimate, 2PT = two-point correlation; a dash means that operation was not benchmarked for that implementation). The points are in 3 dimensions, with each dimension uniformly distributed between 0 and 1. These timings will change in hard-to-predict ways as the distribution of points changes; uniform point distributions in low dimensions tend to favor the KD Tree approach.

2000 points, metric = euclidean
              Build      KNN        RNN        KDE        2PT
cKDTree       0.00022    0.014      —          —          —
old BallTree  0.00072    0.023      0.022      —          —
BallTree      0.0008     0.013      0.011      0.043      0.26
KDTree        0.0011     0.008      0.02       0.031      0.31

2000 points, metric = minkowski (p=3)
              Build      KNN        RNN        KDE        2PT
cKDTree       0.00021    0.0088     —          —          —
old BallTree  0.00095    0.029      0.022      —          —
BallTree      0.00089    0.017      0.012      0.038      0.18
KDTree        0.0011     0.0095     0.013      0.023      0.2

@ogrisel (Member) commented Mar 3, 2013

This looks amazing Jake. Now I wonder when I will be able to find the time to review all of this... :) How do you plan to expose the new kernel density estimation features of these data structures? New estimators with a fit + 1D predict_proba method? The predict_proba method of classifiers might be a confusing API here, though. Or maybe a new (negative_log_)likelihood method on the existing neighbors models?

@ogrisel (Member) commented Mar 3, 2013

BTW, are the KDE methods a faster alternative to scipy.stats.kde.gaussian_kde that could be used to smooth 1D or 2D distribution plots?

@ogrisel (Member) commented Mar 3, 2013

Also, have you done benchmarks in higher dimensions (e.g. 30 to 100)?

@ogrisel (Member) commented Mar 3, 2013

BTW: the Travis failure seems to be a real test failure.

@jakevdp (Member, Author) commented Mar 3, 2013

Regarding the KDE API: I think it would benefit us to come up with a good general API for density estimation within scikit-learn. There are many algorithms based on estimating density (nonparametric/naive Bayes, generative classification, etc.), and there are several useful approaches (KDE, GMM, Bayesian k-neighbors methods, etc.; take a look at the density estimation routines in astroML). A uniform density estimation interface for all of these would make it easier to use and compare them, and would also make it easy to build a set of nonparametric Bayes classifiers that use the various methods.

The speed of the KDE compared to scipy's gaussian_kde depends on several things. Scipy's function is O[N^2] every time. This version can be as bad as O[N^2] with a very large kernel and a very small tolerance; even in that worst case, however, it's about a factor of a few faster than gaussian_kde. As the kernel width is decreased and the tolerance is increased, the number of evaluations theoretically approaches O[N log(N)], or even nearly O[N] using the dual tree approach, so the speed gain should be even more significant. Another difference is that scipy's gaussian_kde seems to be written for statisticians: it includes built-in bandwidth estimation using several approaches common in the statistics community. I've gone for a more functional implementation in the style common in the machine learning community, and I assume the user will implement their own bandwidth selection routine through, e.g., cross-validation.

I haven't yet done detailed benchmarks on data dimension, data size, etc. but I hope to do a blog post on that in the near future.

The test failure is related to the canberra distance. The distances are tested by comparing the results to those of scipy.spatial.distance.cdist(). The canberra distance was implemented incorrectly before scipy version 0.10 (see scipy/scipy@32f9e3d). I believe the jenkins build currently uses scipy 0.9, so that would lead to the errors. What's the best way to deal with that?

@ogrisel (Member) commented Mar 3, 2013

Thanks for the explanation. Looking forward to your blog post.

I think we should add a CVed KDE utility function by default: most of the time, users don't want to implement the CV by hand just to plot a smoothed distribution estimate.

The test failure is related to the canberra distance. The distances are tested by comparing the results to those of scipy.spatial.distance.cdist(). The canberra distance was implemented incorrectly before scipy version 0.10 (see scipy/scipy@32f9e3d). I believe the jenkins build currently uses scipy 0.9, so that would lead to the errors. What's the best way to deal with that?

Check the value of scipy.__version__ in the test and raise a nose.tools.SkipTest in case it's too old.
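
For reference, a minimal sketch of that version check (the test name is illustrative, and the comparison against the new DistanceMetric code is elided):

import numpy as np
import scipy
from scipy.spatial.distance import cdist
from nose import SkipTest

def test_canberra_vs_cdist():
    # scipy fixed its canberra implementation in 0.10; skip the check on older versions
    if tuple(int(v) for v in scipy.__version__.split('.')[:2]) < (0, 10):
        raise SkipTest("scipy < 0.10 computes the canberra distance incorrectly")
    X = np.random.random((10, 3))
    expected = cdist(X, X, 'canberra')
    # ... compare the new DistanceMetric results against `expected` here ...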

@jakevdp (Member, Author) commented Mar 4, 2013

I think we should add a CVed KDE utility function by default: most of the time, users don't want to implement the CV by hand just to plot a smoothed distribution estimate.

Yes, I completely agree. When I said above that I wanted to leave cross-validation to the user, I had in mind that the "user" might be a wrapper function in scikit-learn! I view this tree code primarily as a low-level tool that will provide a good foundation for a lot of different estimators.
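
As an illustration of the kind of wrapper meant here, a rough sketch of bandwidth selection on a held-out split (not part of this PR; the kernel_density call and its keywords are the ones described in the PR summary and may change):

import numpy as np
from sklearn.neighbors import KDTree

def select_bandwidth(X, bandwidths, kernel='gaussian', holdout=0.3, seed=0):
    """Pick the bandwidth maximizing the held-out log-likelihood (sketch)."""
    rng = np.random.RandomState(seed)
    mask = rng.rand(len(X)) < holdout        # simple train / validation split
    X_train, X_valid = X[~mask], X[mask]
    tree = KDTree(X_train)
    best_h, best_ll = None, -np.inf
    for h in bandwidths:
        dens = tree.kernel_density(X_valid, h=h, kernel=kernel)
        ll = np.sum(np.log(dens))
        if ll > best_ll:
            best_h, best_ll = h, ll
    return best_h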

@ogrisel (Member) commented Mar 4, 2013

Ok we agree then :)

@amueller (Member) commented Mar 4, 2013

Wow this is awesome!
I don't fully understand the speed table, though. cKDTree is the scipy implementation, right? To me it looks like that is faster than KDTree. Is that not right?

@jakevdp (Member, Author) commented Mar 4, 2013

Andy - Yes, cKDTree is the scipy implementation. It's a bit faster for some things, much slower for others, and only implements a subset of the operations.

@jakevdp (Member, Author) commented Mar 4, 2013

I should mention the other difference: this KDTree is picklable and uses no dynamic memory allocation: all tree data is stored in pre-allocated numpy arrays. In cKDTree all nodes are allocated dynamically, so it can't be pickled, and it comes with the other disadvantages of dynamic memory allocation. I'm not sure how that affects performance, but it seems to be a wash -- except for the building phase, which is only a small part of the total computation time. I just looked at the code and realized there are some efficiencies that could be enabled in the KDTree node initialization. I hadn't worried about it until now because the build time is so small compared to the time of a query.
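
Since all of the tree state lives in numpy arrays, pickling is just a matter of serializing those arrays. A quick illustration (a sketch; assumes the KDTree class described above):

import pickle
import numpy as np
from sklearn.neighbors import KDTree

X = np.random.random((100, 3))
tree = KDTree(X, leaf_size=20)

s = pickle.dumps(tree)       # works because all state is in pre-allocated arrays
tree2 = pickle.loads(s)
# the reconstructed tree gives identical query results
assert np.allclose(tree.query(X[:5], k=3)[0], tree2.query(X[:5], k=3)[0])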

@jakevdp (Member, Author) commented Mar 9, 2013

Just FYI - though cKDTree is faster on a few things, everything here is faster than the old Ball Tree, so it's a net win for scikit-learn.

@jakevdp (Member, Author) commented Mar 9, 2013

@amueller - I took a closer look at why Scipy's kdtree construction is so much faster than this code. The reason is that the scipy version keeps track of an array of maxes and mins during construction, which saves an iteration over the points in each node. That approach leads to faster construction times, but there's a problem: the min and max bounds of a subnode will generally be smaller than the min and max bounds of its parent. Scipy's version doesn't allow the bounds to shrink at each node. I believe this is the reason that our kdtree outperforms theirs on query: we have tighter constraints on each node. This makes a bit of a difference for uniform data, but I anticipate a much bigger improvement when it comes to structured data -- I plan to do some detailed benchmarks as soon as I'm finished with my PyCon tutorial next week.

In any case, the build time is only a few percent of the query time, even with the slower method here. For that reason, I think we should stick with our code as-is.

I did manage to speed up the build by ~25%, though, by using raw pointers rather than typed memoryviews in some places. It turns out that if a typed memoryview is a class variable, there's some overhead associated with accessing its elements.

@amueller (Member) commented Mar 9, 2013

Thanks for the in-depth analysis Jake :) I'm a bit swamped right now (again) but I'll try to have a look and review asap!

@jaquesgrobler (Member) commented:

Wow Jake, this is very well done. It took me a while just to glance through it all... Will try to look at it more in depth soon. Fantastic work!

@jakevdp (Member, Author) commented Mar 14, 2013

I've been thinking more about this - one issue right now is that the dual tree approach to KDE is sub-optimal for larger atol, and explicitly cannot handle non-zero rtol. I first wrote it this way because I wasn't sure there was an efficient way to implement it, but I'm starting to think it can be done. I'll have to try some things out during the PyCon sprints!

# simultaneous_sort :
# this is a recursive quicksort implementation which sorts `distances`
# and simultaneously performs the same swaps on `indices`.
cdef void _simultaneous_sort(DTYPE_t* dist, ITYPE_t* idx, ITYPE_t size):
Member:

Is this really faster than reusing some off-the-shelf sort such as np.argsort or qsort? It looks like a rather naive quicksort (no random pivoting, median-of-three rule, tail-recursion elimination, back-off to insertion sort, ...).

Member Author:

I have no doubt that the numpy sort is superior to this algorithm. The problem is that here we need to simultaneously sort the indices and the distances. I benchmarked it against numpy using something along the lines of

i = np.argsort(dist)
ind = ind[i]
dist = dist[i]

and found that this cython version was consistently about twice as fast (mainly due to the fancy indexing). If you know of a better way to do this without resorting to cython, please let me know!

Member:

Well, no, I suspect that qsort will be even slower because of all the pointer indirection. But I'm still a bit worried about the naive quicksort, because there are inputs which make this take quadratic time.

Member Author:

True - that's a good point. One piece to keep in mind, though, is that this is only sorting lists of length k, where k is the number of neighbors (typically ~1-50), so it shouldn't be a huge penalty even in the rare worst-case.

Can you think of a good route that doesn't involve re-implementing a more sophisticated sort?

Member:

I'm actually a bit surprised that this is faster than argsort, because the expected number of swaps is around 2 n lg n whereas sorting indices and then using those to sort distances would take 2 n lg n + n.

But no, I can't think of a good alternative -- there's the C++ std::sort, which is practically always implemented as the extremely fast introsort algorithm, but we also have good reasons not to use C++ code if it's not needed.

Is this called at prediction time?

Member Author:

Yes, this is called at prediction/query time, though only if the sort_results keyword is True. I think the reason this is faster is that the constant in front of the O[n] fancy-indexing operations is larger than the constant in front of the O[n log n] swaps. I only benchmarked it for ~1-100 neighbors, which is typical for k-neighbors queries.

Member Author:

I just double-checked the benchmarks I had run... the simultaneous sort is about 10 times faster than the numpy option for O[1] points, and about 3 times faster than the numpy option for O[1000] points.

I was thinking that if that difference were a negligible percentage of the query time, it might be cleaner just to use the numpy option. But when searching for a large number of neighbors, the sort is on the order of 30% of the query cost, so it's definitely worth worrying about!

I also realized that I haven't yet implemented the sort_results keyword I mentioned above. With these benchmarks in mind, that seems like a really good thing to have -- I'll take a stab at implementing it when my PyCon/PyData talks are done!

Member:

Alright. My concern is that it may be possible to DoS a web app that calls predict on user input, but I guess we'll have to take the risk. (Median-of-three pivoting, while still O(n²), would still be nice to have as it fixes almost all non-adversarial performance problems.)

Member Author:

Median-of-three pivoting sounds like a good idea -- I'll plan to put that in.

@larsmans (Member) commented:

Just out of curiosity, how long does this take to compile? It's 30kLOC in C...

@jakevdp (Member, Author) commented Mar 16, 2013

On my newish macbook pro, compilation takes ~8 seconds.

@jakevdp (Member, Author) commented Mar 18, 2013

@larsmans - I added median-of-three pivoting, as well as using a fast direct sort for partitions less than or equal to length 3.

I also added the sort_results keyword to the query command. Thanks for all the feedback on this!
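
For reviewers, here is a pure-Python sketch of the scheme just described -- median-of-three pivot selection plus a direct sort for partitions of length <= 3, applied simultaneously to the distance and index arrays. The real implementation is in Cython; this is only meant to show the logic, not to mirror the code exactly.

def simultaneous_sort(dist, idx, lo=0, hi=None):
    """Sort dist[lo:hi] in place, applying identical swaps to idx (sketch)."""
    if hi is None:
        hi = len(dist)

    def swap(i, j):
        dist[i], dist[j] = dist[j], dist[i]
        idx[i], idx[j] = idx[j], idx[i]

    n = hi - lo
    if n <= 1:
        return
    if n <= 3:
        # direct (insertion) sort for partitions of 2 or 3 elements
        for i in range(lo + 1, hi):
            j = i
            while j > lo and dist[j] < dist[j - 1]:
                swap(j, j - 1)
                j -= 1
        return

    # median-of-three: order (first, middle, last), then use the middle value as pivot
    mid = lo + n // 2
    if dist[mid] < dist[lo]:
        swap(mid, lo)
    if dist[hi - 1] < dist[lo]:
        swap(hi - 1, lo)
    if dist[hi - 1] < dist[mid]:
        swap(hi - 1, mid)
    swap(mid, hi - 2)              # stash the pivot just before the last element
    pivot = dist[hi - 2]

    # partition the interior around the pivot value
    store = lo + 1
    for i in range(lo + 1, hi - 2):
        if dist[i] < pivot:
            swap(i, store)
            store += 1
    swap(store, hi - 2)            # move the pivot into its final position

    simultaneous_sort(dist, idx, lo, store)
    simultaneous_sort(dist, idx, store + 1, hi)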

@jakevdp (Member, Author) commented Mar 22, 2013

Dang, I just realized that my new quicksort algorithm broke something... I need to figure out how to fix it.

@jakevdp (Member, Author) commented Mar 22, 2013

Another small addition: the ability to pass a python function defining a user metric. This is slow due to the python call overhead, but based on an email from a user today, it seems like it would be interesting to allow the option.

@jakevdp (Member, Author) commented Mar 22, 2013

Another TODO: implement the haversine formula. This is a 2D metric which gives the angle between points on the surface of a sphere, where the positions can be specified in latitude and longitude. It should be fairly easy to implement, and useful in astronomy applications - thanks to @jsundram for the suggestion.
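
For reference, a sketch of the math in plain Python (the actual metric in the PR would live alongside the other Cython distance metric classes):

import numpy as np

def haversine_distance(x1, x2):
    """Angular distance between two (latitude, longitude) points given in radians (sketch)."""
    lat1, lon1 = x1
    lat2, lon2 = x2
    sin_dlat = np.sin(0.5 * (lat2 - lat1))
    sin_dlon = np.sin(0.5 * (lon2 - lon1))
    a = sin_dlat ** 2 + np.cos(lat1) * np.cos(lat2) * sin_dlon ** 2
    return 2 * np.arcsin(np.sqrt(a))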

@larsmans (Member) commented:

Also useful for GIS or anything using location information!

@jakevdp (Member, Author) commented Mar 23, 2013

OK, I added the Haversine distance & associated tests, and did some cleanup (including adding the ability for our functions to pass on exceptions). I think this metric will be a useful addition!

@jsundram commented:

Looks good to me. I don't have a good sense for how (in)accurate this is for (nearly) antipodal points. Do you think it may be worth checking for that, and using the cosine-based formula instead?

"maching" MatchingDistance NNEQ / N
"dice" DiceDistance NNEQ / (NTT + NNZ)
"kulsinski" KulsinskiDistance (NNEQ + N - NTT) / (NNEQ + N)
"rogerstanimoto" RogerStanimotoDistance 2 * NNEQ / (N + NNEQ)
Member:

My gut feeling is that to follow scikit-learn's naming convention, this should be 'roger_stanimoto' (a similar remark applies below). However, maybe you are using these strings to be compatible with another library?

Member:

It's actually called the Rogers-Tanimoto distance (not Roger-Stanimoto).

Member Author:

Ha! Thanks @larsmans, I'll change that.

Gael, the reason I chose the string identifiers I did is because they match what's used within pdist and cdist in the scipy.spatial.distance module.

Member:

Gael, the reason I chose the string identifiers I did is because they match what's used within pdist and cdist in the scipy.spatial.distance module.

OK, sounds good. I like to have underscores separating words for readability, but consistency with scipy is definitely a major point.

Member Author:

The other option is to allow both forms: document a scikit-learn version with underscores, but also silently accept the no-underscore version for compatibility with scipy. I do that already for the manhattan distance: scipy calls it 'cityblock', so the code accepts that as well.
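
A minimal sketch of how such aliasing could look (the mapping below is illustrative, not the exact table used in the code):

# scipy-compatible identifiers silently accepted as aliases for the documented names
METRIC_ALIASES = {
    'cityblock': 'manhattan',              # scipy's name for the manhattan distance
    'rogerstanimoto': 'rogers_tanimoto',   # hypothetical underscored spelling
}

def resolve_metric(name):
    """Map scipy-style identifiers onto the canonical scikit-learn names (sketch)."""
    return METRIC_ALIASES.get(name, name)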

@GaelVaroquaux (Member) commented:

Looks gorgeous! Fantastic.

Both BallTree and KDTree now have a method for performing fast kernel density estimates in O[N log(N)] rather than the naive O[N^2]. Six common kernels are implemented, and each kernel can work with any of the valid distance metrics.

I think that it would be good to have a high-level API to access this functionality.

I was thinking of a new estimator (as suggested by @ogrisel), but using the 'score' method rather than the predict_proba one. IMHO, the score method is pretty much suitable for a general density estimation API, and it is already partly used like this in the scikit (in the PCA or in the covariance matrix estimators).

Both BallTree and KDTree have a method for computing fast 2-point correlation functions, also in O[N log(N)] rather than the naive O[N^2]. This works with any of the valid distance metrics.

As a bonus, there is a set of class interfaces to the various distance metrics, with functions to generate a pairwise distance matrix under any of the metrics.

Awesome. Is there any way that these could be integrated more closely with the sklearn.metrics.pairwise module? Also, it seems to me that the work you have done should give speedups for the construction of a kneighbors_graph and a radius_neighbors_graph. Am I wrong? I see no modification of these functions, so I am wondering.
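
To make the score-based suggestion concrete, here is a rough sketch of what such an estimator could look like. The class name and the kernel_density call are assumptions for illustration, not part of this PR:

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.neighbors import KDTree

class TreeKDE(BaseEstimator):
    """Hypothetical kernel density estimator wrapping the new tree code (sketch)."""

    def __init__(self, bandwidth=1.0, kernel='gaussian', leaf_size=40):
        self.bandwidth = bandwidth
        self.kernel = kernel
        self.leaf_size = leaf_size

    def fit(self, X, y=None):
        self.tree_ = KDTree(X, leaf_size=self.leaf_size)
        return self

    def score(self, X, y=None):
        # total log-likelihood of X under the fitted density, in the spirit of
        # the score methods of PCA and the covariance estimators
        dens = self.tree_.kernel_density(X, h=self.bandwidth, kernel=self.kernel)
        return np.sum(np.log(dens))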

MISC: remove unused imports
@GaelVaroquaux (Member) commented:

OK, @jakevdp , I want this guy in: it has been lying around for too long, and is ready apart from the empty lines issue. I am finishing it off and merging.

@GaelVaroquaux (Member) commented:

Merged! (I didn't rebase, because it was creating a conflict hell).

Bravo Jake.

@jakevdp (Member, Author) commented Jul 14, 2013

Awesome - thanks!
Any idea why this says there are unmerged commits? Is it just that PR merge commit from you yesterday?

@GaelVaroquaux (Member) commented:

Awesome - thanks!

Not that awesome, Jenkins is failing on us:
http://jenkins.shiningpanda.com/scikit-learn/job/python-2.6-numpy-1.3.0-scipy-0.7.2/lastFailedBuild/console
The tests ran fine on my box :(.

Can you have a look at that? I am waiting to get on a plane and could be offline at any time.

Any idea why this says there are unmerged commits? Is it just that PR merge commit from you yesterday?

Yes, I short-cut it to avoid the merge commit.

@jakevdp (Member, Author) commented Jul 14, 2013

Is Jenkins different from Travis? The Travis build was doing just fine.

@jakevdp (Member, Author) commented Jul 14, 2013

I don't understand some of these failures... in particular

ValueError: Buffer dtype mismatch, expected 'ITYPE_t' but got Python object in 'NodeData_t.idx_start'

NodeData_t.idx_start is cdef'd as type ITYPE_t, and I'm not sure how it could ever be interpreted as a Python object. I'm going to see if I can figure it out.

@jakevdp (Member, Author) commented Jul 14, 2013

I think that all the problems can be traced to this:

# use a dummy variable to determine the python data type
cdef NodeData_t dummy
cdef NodeData_t[:] dummy_view = <NodeData_t[:1]> &dummy
NodeData = np.asarray(dummy_view).dtype

Recall that NodeData_t is a C structure:

cdef struct NodeData_t:
    ITYPE_t idx_start
    ITYPE_t idx_end
    int is_leaf
    DTYPE_t radius

I think that on the Jenkins machine, the resulting NodeData dtype does not match the C type denoted by NodeData_t. I've been using this sort of hack for a long time and have never seen it fail.

Anybody have better ideas about how to take a cdef'd struct and extract an equivalent numpy dtype? Last year I sent this snippet to the cython list, asking if anyone knew a better way, and there wasn't a response.

@larsmans (Member) commented:

Why do you need such a hack at all? Can't you just malloc a NodeData_t array?

@jakevdp (Member, Author) commented Jul 15, 2013

Why do you need such a hack at all? Can't you just malloc a NodeData_t array?

I'd prefer the memory management to be handled by numpy -- it seems cleaner that way. I don't use malloc anywhere in this code, and consider that a positive feature.

@ogrisel (Member) commented Jul 15, 2013

Note that the failure is at line:

self.node_data = np.empty(0, dtype=NodeData, order='C')

This only fails on the Python 2.6 / numpy 1.3 build, so it's probably related to the numpy version.

@ogrisel (Member) commented Jul 15, 2013

I tried to install numpy 1.3 on Python 2.7 but the numpy build fails. Maybe we should stop trying to support such an old version of numpy.

Less than 5% of scipy stack users who answered this survey earlier this year use numpy 1.3:

http://astrofrog.github.io/blog/2013/01/13/what-python-installations-are-scientists-using/

The most recent Ubuntu LTS (Precise Pangolin, 12.04) ships numpy 1.6.1 by default.

@GaelVaroquaux what is the numpy version in your institute? I will try to build with numpy 1.5.

@larsmans (Member) commented:

Don't Cython views have some kind of magic memory allocation that we can use? I understand the fear of malloc, but this doesn't look very clean either.

@jakevdp (Member, Author) commented Jul 15, 2013

It's more than just a fear of malloc -- using numpy arrays also makes serialization extremely straightforward. I'm not sure about cython's memory allocation. Given that this fails on numpy 1.3, I wouldn't be surprised if the same error occurred when we convert the cython typed memoryview to a numpy array to serialize the object via __getstate__ and __setstate__.

@jakevdp (Member, Author) commented Jul 15, 2013

Actually, it looks like numpy 1.4 and below do not support the buffer interface, so almost nothing in this PR would work under those numpy versions:

http://comments.gmane.org/gmane.comp.python.cython.devel/14937

I didn't realize that detail... so I guess our choice is to either scrap this PR in full, or to drop numpy 1.3-1.4 support (for reference, numpy 1.5 was released in August 2010). Another option would be to rewrite all of this using the deprecated cython numpy array syntax, but that would feel like going backward to me.

If we deem that supporting old numpy versions is important, it's not all lost: I'd release all of this functionality as its own mini-package for the time being, and we could revisit its inclusion once the numpy requirement is upgraded to 1.5.

@ogrisel (Member) commented Jul 15, 2013

Maybe we should ask for users opinions on the mailing list:

  • which version of numpy do you use?
  • would you agree to raise the minimal version for numpy to 1.5?

@larsmans (Member) commented:

FYI, the most recent version of CentOS still ships NumPy 1.4. That's probably true of Red Hat Enterprise Linux and its other clones as well.

@ogrisel (Member) commented Jul 15, 2013

That's bad news.

@ogrisel (Member) commented Jul 15, 2013

Another option would be to rewrite all of this using the deprecated cython numpy array syntax, but that would feel like going backward to me.

I agree, but on the other hand this is for backward compat. Any difference in runtime performance?

@jakevdp (Member, Author) commented Jul 15, 2013

I've sent a message to the cython users list: hopefully that will generate some ideas.

@GaelVaroquaux (Member) commented:

I'd prefer the memory management to be handled by numpy -- it seems cleaner that way. I don't use malloc anywhere in this code, and consider that a positive feature.

I am usually happier if I can avoid malloc/free.

@ogrisel (Member) commented Jul 16, 2013

But memory allocation here is not the problem, or is it? It seems that the problem is that Cython's memoryview feature does not work with old versions of NumPy (1.3 and 1.4), but only with NumPy versions that implement the buffer interface (NumPy 1.5 and on).

So the real question is: do we bump the requirement for sklearn users up to NumPy 1.5+, or do we rewrite the ball / kd-trees to use the old, deprecated cython-numpy array syntax instead of the memoryview syntax?

@jakevdp (Member, Author) commented Jul 16, 2013

I think I've come up with a solution that will work without too much rewriting of the code. We can use a cython utility function that looks like this:

cimport numpy as np

def array_to_memview(np.ndarray[double, ndim=1, mode='c'] X):
    cdef double* Xdata = <double*> X.data                   # raw pointer to the numpy buffer
    cdef double[::1] Xmem = <double[:X.shape[0]]> Xdata     # re-wrap the pointer as a typed memoryview
    return Xmem

All external-facing functions would have to be changed so as to explicitly take numpy arrays as arguments, and then perform these gymnastics to convert the input. As long as we maintain references to the original object for as long as the memoryview is needed, we shouldn't run into any memory expiration issues.

One complication: as there is no templating, I think we'd need to explicitly define a different conversion function like the one above for each combination of data type and number of dimensions. The result is going to be code that's pretty messy and confusing to read, but it should be backward compatible to numpy 1.3.

@ogrisel (Member) commented Jul 16, 2013

I think that's worth trying. Just a cosmetic note: Xdata => X_data, and so on.

@GaelVaroquaux (Member) commented:

I am 👍 on that option. I think that dropping 1.3 compatibility would have been an option, but dropping 1.4 is a bit violent.
